diff --git a/CHANGELOG.md b/CHANGELOG.md index 1426b6d167..bebe7635f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ with the exception that minor releases may include breaking changes. ### Added +- ✨ Add a `fuse-single-qubit-unitary-runs` pass + for fusing compile-time single-qubit unitary runs via Euler resynthesis + ([#1672]) ([**@simon1hofmann**], [**@burgholzer**]) - ✨ Add QIR program format support to the DDSIM QDMI Device ([#1766]) ([**@rturrado**]) - 🚸 Add [CMake presets] to provide a standardized @@ -621,6 +624,7 @@ changelogs._ [#1675]: https://github.com/munich-quantum-toolkit/core/pull/1675 [#1674]: https://github.com/munich-quantum-toolkit/core/pull/1674 [#1673]: https://github.com/munich-quantum-toolkit/core/pull/1673 +[#1672]: https://github.com/munich-quantum-toolkit/core/pull/1672 [#1664]: https://github.com/munich-quantum-toolkit/core/pull/1664 [#1662]: https://github.com/munich-quantum-toolkit/core/pull/1662 [#1660]: https://github.com/munich-quantum-toolkit/core/pull/1660 diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h index 9ec5b5f384..1b08a1715a 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h +++ b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h @@ -10,6 +10,8 @@ #pragma once +#include "mlir/Dialect/Utils/Utils.h" + #include #include #include @@ -127,6 +129,16 @@ template class TargetAndParameterArityTrait { llvm::reportFatalUsageError( "Given qubit is not an input of the operation"); } + + [[nodiscard]] bool hasCompileTimeKnownUnitaryMatrix() { + if constexpr (P == 0) { + return true; + } else { + return llvm::all_of(this->getParameters(), [](Value param) { + return utils::valueToDouble(param).has_value(); + }); + } + } }; }; @@ -151,8 +163,9 @@ inline func::FuncOp getEntryPoint(ModuleOp op) { }; for (auto func : op.getOps()) { - const auto passthrough = func->getAttrOfType(PASSTHROUGH_LABEL); - if (passthrough && llvm::any_of(passthrough, isEntry)) { + if (const auto passthrough = + func->getAttrOfType(PASSTHROUGH_LABEL); + passthrough && llvm::any_of(passthrough, isEntry)) { return func; } } diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td index 854a3af5a8..9450fcab74 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td @@ -39,6 +39,11 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { return true; } return false; + } else if constexpr (std::is_same_v) { + if (auto matrix = $_op.getUnitaryMatrix()) { + return out.assignFrom(*matrix); + } + return false; } else { return false; } @@ -63,6 +68,11 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { return true; } return false; + } else if constexpr (std::is_same_v) { + if (auto matrix = $_op.getUnitaryMatrix()) { + return out.assignFrom(*matrix); + } + return false; } else { return false; } @@ -87,6 +97,11 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { return true; } return false; + } else if constexpr (std::is_same_v) { + if (auto matrix = $_op.getUnitaryMatrix()) { + return out.assignFrom(*matrix); + } + return false; } else { return false; } @@ -188,6 +203,9 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { "StringRef", "getBaseSymbol", (ins)>, // Unitary matrix helpers + InterfaceMethod<"Returns true if the operation has a compile-time known " + "unitary matrix representation, false otherwise.", + "bool", "hasCompileTimeKnownUnitaryMatrix", (ins)>, InterfaceMethod<"Populates the given 1x1 unitary matrix if possible.", "bool", "getUnitaryMatrix1x1", (ins "Matrix1x1&":$out), unitaryMatrix1x1MethodBody>, diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 17bab8174b..f6283bf943 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1030,6 +1030,8 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { static Value getParameter(size_t i) { llvm::reportFatalUsageError("BarrierOp has no parameters"); } static OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "barrier"; } + [[nodiscard]] bool hasCompileTimeKnownUnitaryMatrix() const { return true; } + [[nodiscard]] DynamicMatrix getUnitaryMatrix(); }]; let builders = [OpBuilder<(ins "ValueRange":$qubits)>]; @@ -1126,6 +1128,7 @@ def CtrlOp : QCOOp<"ctrl", Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; } + [[nodiscard]] bool hasCompileTimeKnownUnitaryMatrix(); [[nodiscard]] std::optional getUnitaryMatrix(); }]; @@ -1199,6 +1202,7 @@ def InvOp : QCOOp<"inv", traits = [UnitaryOpInterface, Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } OperandRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "inv"; } + [[nodiscard]] bool hasCompileTimeKnownUnitaryMatrix(); [[nodiscard]] std::optional getUnitaryMatrix(); }]; diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h b/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h new file mode 100644 index 0000000000..8fb0018b28 --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/QCO/Utils/Matrix.h" + +#include +#include +#include + +#include +#include +#include + +namespace mlir::qco::decomposition { + +/** + * @brief Native gate sets for single-qubit Euler synthesis. + */ +enum class EulerBasis : std::uint8_t { + ZYZ = 0, ///< `RZ(phi) * RY(theta) * RZ(lambda)`. + ZXZ = 1, ///< `RZ(phi) * RX(theta) * RZ(lambda)`. + XZX = 2, ///< `RX(phi) * RZ(theta) * RX(lambda)`. + XYX = 3, ///< `RX(phi) * RY(theta) * RX(lambda)`. + U = 4, ///< `U(theta, phi, lambda)`. + ZSXX = 5, ///< `RZ` / `SX` / `X` synthesis via ZYZ decomposition. +}; + +/** + * @brief Parses a basis name (e.g. `zyz`, `zsxx`; case-insensitive). + * + * @param basis The basis name. + * @return The parsed basis, or `std::nullopt` if unrecognized. + */ +[[nodiscard]] std::optional parseEulerBasis(StringRef basis); + +/** + * @brief Synthesizes a composed single-qubit unitary as gates in @p basis. + * + * Returns `std::nullopt` when @p hasNonBasisGate is false and resynthesis + * would not shorten a run of @p runSize gates; otherwise emits gates + * (including `qco.gphase` when needed). + * + * @param builder Builder for the emitted operations. + * @param loc Location for the emitted operations. + * @param qubit Input qubit value. + * @param composed Composed unitary to synthesize. + * @param runSize Number of gates in the run. + * @param hasNonBasisGate Whether the run contains a gate outside @p basis. + * @param basis The target Euler basis. + * @return The synthesized qubit, or `std::nullopt` if synthesis is skipped. + */ +[[nodiscard]] std::optional +synthesizeUnitary1QEuler(OpBuilder& builder, Location loc, Value qubit, + const Matrix2x2& composed, std::size_t runSize, + bool hasNonBasisGate, EulerBasis basis); + +} // namespace mlir::qco::decomposition diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 32f678924e..fafb906a77 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -40,6 +40,32 @@ def MergeSingleQubitRotationGates }]; } +def FuseSingleQubitUnitaryRuns + : Pass<"fuse-single-qubit-unitary-runs", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + "::mlir::qtensor::QTensorDialect"]; + let summary = "Fuse single-qubit unitary runs using Euler resynthesis"; + let description = [{ + Matches maximal runs of consecutive single-qubit unitary operations on the + same qubit wire (anchored at each run head), composes their constant unitary + matrices, and replaces a run with an equivalent sequence of basis gates when + beneficial: when the run contains a gate outside the target `basis`, or when + Euler resynthesis would shorten it (`synthesizeUnitary1QEuler`). Runs that are + already in the target `basis` and no shorter than the canonical synthesis + length are left unchanged. + + The emitted basis is controlled via the `basis` option (e.g. `zyz`, `zsxx`). + A `gphase` correction is inserted when needed so the rewritten sequence + matches the composed matrix exactly (not only up to global phase). + + Currently, only operations whose unitary matrix can be obtained at compile + time are fused. + }]; + let options = [Option<"basis", "basis", "std::string", "\"zyz\"", + "Target Euler basis (zyz, zxz, xzx, xyx, u, zsxx).">]; +} + def QuantumLoopUnroll : InterfacePass<"quantum-loop-unroll", "FunctionOpInterface"> { let dependentDialects = ["mlir::qco::QCODialect", "mlir::scf::SCFDialect"]; diff --git a/mlir/include/mlir/Dialect/QCO/Utils/Matrix.h b/mlir/include/mlir/Dialect/QCO/Utils/Matrix.h index 2305a06346..8f49d372d7 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/Matrix.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/Matrix.h @@ -25,6 +25,8 @@ using Complex = std::complex; /// Default absolute tolerance for matrix comparisons. inline constexpr double MATRIX_TOLERANCE = 1e-14; +class DynamicMatrix; + /** * @brief 1x1 matrix for global-phase gates. * @@ -58,6 +60,26 @@ struct Matrix1x1 { */ [[nodiscard]] Complex operator()(std::size_t row, std::size_t col) const; + /** + * @brief Element-wise scaling by a complex scalar. + * @param scalar Factor applied to the matrix entry. + * @return Scaled copy of this matrix. + */ + [[nodiscard]] Matrix1x1 operator*(const Complex& scalar) const; + + /** + * @brief Element-wise in-place scaling by a complex scalar. + * @param scalar Factor applied to the matrix entry. + * @return Reference to this matrix. + */ + Matrix1x1& operator*=(const Complex& scalar); + + /** + * @brief Returns the conjugate transpose (adjoint) of this matrix. + * @return Adjoint matrix `A^\dagger`. + */ + [[nodiscard]] Matrix1x1 adjoint() const; + /** * @brief Checks approximate equality using an absolute tolerance. * @param other Matrix to compare against. @@ -66,6 +88,14 @@ struct Matrix1x1 { */ [[nodiscard]] bool isApprox(const Matrix1x1& other, double tol = MATRIX_TOLERANCE) const; + + /** + * @brief Replaces this matrix with a copy of a 1x1 dynamic matrix. + * + * @param src Source matrix. + * @return `true` when @p src is 1x1. + */ + [[nodiscard]] bool assignFrom(const DynamicMatrix& src); }; /** @@ -127,6 +157,26 @@ struct Matrix2x2 { */ [[nodiscard]] Matrix2x2 operator*(const Matrix2x2& rhs) const; + /** + * @brief Premultiplies by a matrix: `*this = lhs * *this`. + * @param lhs Left-hand factor. + */ + void premultiplyBy(const Matrix2x2& lhs); + + /** + * @brief Element-wise scaling by a complex scalar. + * @param scalar Factor applied to every matrix entry. + * @return Scaled copy of this matrix. + */ + [[nodiscard]] Matrix2x2 operator*(const Complex& scalar) const; + + /** + * @brief Element-wise in-place scaling by a complex scalar. + * @param scalar Factor applied to every matrix entry. + * @return Reference to this matrix. + */ + Matrix2x2& operator*=(const Complex& scalar); + /** * @brief Returns the conjugate transpose (adjoint) of this matrix. * @return Adjoint matrix `A^\dagger`. @@ -157,6 +207,14 @@ struct Matrix2x2 { */ [[nodiscard]] bool isApprox(const Matrix2x2& other, double tol = MATRIX_TOLERANCE) const; + + /** + * @brief Replaces this matrix with a copy of a 2x2 dynamic matrix. + * + * @param src Source matrix. + * @return `true` when @p src is 2x2. + */ + [[nodiscard]] bool assignFrom(const DynamicMatrix& src); }; /** @@ -238,6 +296,26 @@ struct Matrix4x4 { */ [[nodiscard]] Matrix4x4 operator*(const Matrix4x4& rhs) const; + /** + * @brief Premultiplies by a matrix: `*this = lhs * *this`. + * @param lhs Left-hand factor. + */ + void premultiplyBy(const Matrix4x4& lhs); + + /** + * @brief Element-wise scaling by a complex scalar. + * @param scalar Factor applied to every matrix entry. + * @return Scaled copy of this matrix. + */ + [[nodiscard]] Matrix4x4 operator*(const Complex& scalar) const; + + /** + * @brief Element-wise in-place scaling by a complex scalar. + * @param scalar Factor applied to every matrix entry. + * @return Reference to this matrix. + */ + Matrix4x4& operator*=(const Complex& scalar); + /** * @brief Returns the conjugate transpose (adjoint) of this matrix. * @return Adjoint matrix `A^\dagger`. @@ -268,6 +346,14 @@ struct Matrix4x4 { */ [[nodiscard]] bool isApprox(const Matrix4x4& other, double tol = MATRIX_TOLERANCE) const; + + /** + * @brief Replaces this matrix with a copy of a 4x4 dynamic matrix. + * + * @param src Source matrix. + * @return `true` when @p src is 4x4. + */ + [[nodiscard]] bool assignFrom(const DynamicMatrix& src); }; /** @@ -288,6 +374,18 @@ class DynamicMatrix { */ explicit DynamicMatrix(std::int64_t dim); + /** + * @brief Creates a dynamic matrix from a fixed 2x2 matrix. + * @param src Source matrix. + */ + explicit DynamicMatrix(const Matrix2x2& src); + + /** + * @brief Creates a dynamic matrix from a fixed 4x4 matrix. + * @param src Source matrix. + */ + explicit DynamicMatrix(const Matrix4x4& src); + /// Copy constructor. DynamicMatrix(const DynamicMatrix& other); /// Move constructor. @@ -306,6 +404,13 @@ class DynamicMatrix { */ [[nodiscard]] static DynamicMatrix identity(std::int64_t dim); + /** + * @brief Creates a dynamic matrix holding the adjoint of a 2x2 matrix. + * @param src Source matrix. + * @return Adjoint matrix `src^\dagger`. + */ + [[nodiscard]] static DynamicMatrix fromAdjoint(const Matrix2x2& src); + /** * @brief Returns the number of rows. * @return Matrix dimension. @@ -386,6 +491,18 @@ class DynamicMatrix { */ void assignFrom(const DynamicMatrix& src); + /** + * @brief Checks approximate equality against a fixed 1x1 matrix. + * + * Returns false if this matrix is not 1x1. + * + * @param other Fixed-size matrix to compare against. + * @param tol Maximum allowed complex modulus of the entry difference. + * @return True if dimensions match and the entry differs by at most @p tol. + */ + [[nodiscard]] bool isApprox(const Matrix1x1& other, + double tol = MATRIX_TOLERANCE) const; + /** * @brief Checks approximate equality against a fixed 2x2 matrix. * diff --git a/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h b/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h index 70ba4b26ff..6e5b395413 100644 --- a/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h +++ b/mlir/include/mlir/Dialect/QCO/Utils/WireIterator.h @@ -116,4 +116,25 @@ template <> struct WireTraversalTraits { : !isa(it.operation()); } }; + +/** + * @brief A range over the def-use chain of a qubit wire, usable in range-based + * for-loops. + * + * Example: + * @code + * for (auto* op : WireRange(qubit)) { ... } + * @endcode + */ +struct WireRange { + explicit WireRange(Value qubit) : begin_(qubit) {} + + [[nodiscard]] WireIterator begin() const { return begin_; } + [[nodiscard]] static std::default_sentinel_t end() { + return std::default_sentinel; + } + +private: + WireIterator begin_; +}; } // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 91c8d341f4..4135ca1769 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -26,6 +26,8 @@ namespace mlir::utils { +/// Default absolute tolerance for MLIR dialect numerics (angle wrapping, +/// phase-zero checks). constexpr auto TOLERANCE = 1e-15; inline Value constantFromScalar(OpBuilder& builder, Location loc, double v) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 9c5a46837a..0bec883fae 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -293,6 +293,13 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, results.add(context); } +bool CtrlOp::hasCompileTimeKnownUnitaryMatrix() { + return all_of(getBody()->getOps(), + [](UnitaryOpInterface op) { + return op.hasCompileTimeKnownUnitaryMatrix(); + }); +} + std::optional CtrlOp::getUnitaryMatrix() { auto bodyUnitary = utils::getSoleBodyUnitary(*getBody()); if (!bodyUnitary) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 4bb5bf4d9f..c6d6a36679 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -29,6 +29,8 @@ #include #include +#include +#include #include #include #include @@ -403,15 +405,74 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, CancelNestedInv, EraseEmptyInv>(context); } -std::optional InvOp::getUnitaryMatrix() { - auto bodyUnitary = utils::getSoleBodyUnitary(*getBody()); - if (!bodyUnitary) { +bool InvOp::hasCompileTimeKnownUnitaryMatrix() { + return all_of(getBody()->getOps(), + [](UnitaryOpInterface op) { + return op.hasCompileTimeKnownUnitaryMatrix(); + }); +} + +/** + * @brief Composes compile-time single-qubit unitaries and returns the inverse. + */ +[[nodiscard]] static std::optional +composeInvertedSingleQubitBodyMatrix(Block& block) { + Matrix2x2 acc = Matrix2x2::identity(); + Complex global{1.0, 0.0}; + bool found = false; + for (Operation& op : block.without_terminator()) { + if (!TypeSwitch(&op) + .Case([](auto) { return true; }) + .Case([&](GPhaseOp gphase) { + const auto matrix = gphase.getUnitaryMatrix(); + if (!matrix) { + return false; + } + global *= (*matrix)(0, 0); + return true; + }) + .Case([&](UnitaryOpInterface unitary) { + Matrix2x2 matrix; + if (!unitary.getUnitaryMatrix2x2(matrix)) { + return false; + } + acc.premultiplyBy(matrix); + found = true; + return true; + }) + .Default([](Operation* operation) { + const auto usesQubit = [](Value value) { + return isa(value.getType()); + }; + return !llvm::any_of(operation->getOperands(), usesQubit) && + !llvm::any_of(operation->getResults(), usesQubit); + })) { + return std::nullopt; + } + } + if (!found && std::abs(global - Complex{1.0, 0.0}) <= utils::TOLERANCE) { return std::nullopt; } - const auto targetMatrix = bodyUnitary.getUnitaryMatrix(); - if (!targetMatrix) { + acc *= global; + return DynamicMatrix::fromAdjoint(acc); +} + +std::optional InvOp::getUnitaryMatrix() { + if (getNumBodyUnitaries() == 0) { + return DynamicMatrix::identity(1LL << getNumTargets()); + } + + if (auto bodyUnitary = + utils::getSoleBodyUnitary(*getBody())) { + if (const auto targetMatrix = + bodyUnitary.getUnitaryMatrix()) { + return targetMatrix->adjoint(); + } return std::nullopt; } - return targetMatrix->adjoint(); + if (getNumTargets() != 1) { + return std::nullopt; + } + return composeInvertedSingleQubitBodyMatrix(*getBody()); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp index 1b6b50fd6d..6aed75b199 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/BarrierOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Utils/Matrix.h" #include #include @@ -109,3 +110,8 @@ void BarrierOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); } + +DynamicMatrix BarrierOp::getUnitaryMatrix() { + const auto numQubits = getQubitsIn().size(); + return DynamicMatrix::identity(1LL << numQubits); +} diff --git a/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp b/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp new file mode 100644 index 0000000000..085d493abf --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp @@ -0,0 +1,444 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" + +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Utils/Matrix.h" +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace mlir::qco::decomposition { + +/** + * @brief Wraps `angle` into `[-pi, pi)`, mapping `+pi` (within tolerance) to + * `-pi`. + * + * @param angle The angle to wrap, in radians. + * @return The wrapped angle in `[-pi, pi)`. + */ +[[nodiscard]] static double mod2pi(const double angle) { + if (!std::isfinite(angle)) { + return angle; + } + + constexpr double pi = std::numbers::pi; + constexpr double twoPi = 2.0 * std::numbers::pi; + + double r = std::fmod(angle + pi, twoPi); + if (r < 0.0) { + r += twoPi; + } + double wrapped = r - pi; + + if (wrapped >= pi - utils::TOLERANCE) { + wrapped = -pi; + } + + return wrapped; +} + +/** + * @brief Conjugates a single-qubit matrix by Hadamard (`H * m * H`). + * + * Maps XYX / XZX parameterizations to ZYZ / ZXZ. + * + * @param m The single-qubit matrix to conjugate. + * @return `H * m * H`. + */ +[[nodiscard]] static Matrix2x2 hadamardConjugate(const Matrix2x2& m) { + const auto a = m(0, 0); + const auto b = m(0, 1); + const auto c = m(1, 0); + const auto d = m(1, 1); + return Matrix2x2::fromElements(0.5 * (a + b + c + d), 0.5 * (a - b + c - d), + 0.5 * (a + b - c - d), 0.5 * (a - b - c + d)); +} + +/** + * @brief Whether `angle` is numerically zero for gate-emission purposes. + * + * @param angle Rotation angle in radians. + * @return `true` when no rotation gate should be emitted. + */ +[[nodiscard]] static bool isNearZeroRotationAngle(const double angle) { + return std::abs(angle) <= utils::TOLERANCE; +} + +/** + * @brief Emits `qco.gphase` when `phase` is outside tolerance. + * + * @param builder Builder for the operation. + * @param loc Location of the operation. + * @param phase Global phase in radians. + */ +static void emitGPhaseIfNeeded(OpBuilder& builder, Location loc, double phase) { + if (isNearZeroRotationAngle(mod2pi(phase))) { + return; + } + GPhaseOp::create(builder, loc, phase); +} + +//===----------------------------------------------------------------------===// +// Euler decomposition (angles) +//===----------------------------------------------------------------------===// + +/** + * @brief Euler angles `(theta, phi, lambda)` and global phase for a 2x2 + * unitary. + */ +namespace { + +struct EulerAngles { + double theta = 0.0; ///< Middle rotation angle. + double phi = 0.0; ///< First outer rotation angle. + double lambda = 0.0; ///< Second outer rotation angle. + double phase = 0.0; ///< Global phase in radians. +}; + +} // namespace + +/** + * @brief Z-Y-Z Euler angles and global phase for a 2x2 unitary. + * + * @param matrix Single-qubit unitary to decompose. + * @return Z-Y-Z angles and global phase. + */ +[[nodiscard]] static EulerAngles paramsZYZ(const Matrix2x2& matrix) { + // det(U) = exp(2i*phase) + const Complex det = matrix.determinant(); + const auto detArg = std::arg(det); + const auto phase = 0.5 * detArg; + const auto theta = + 2. * std::atan2(std::abs(matrix(1, 0)), std::abs(matrix(0, 0))); + const auto ang1 = std::arg(matrix(1, 1)); + double ang2 = 0.0; + if (std::abs(matrix(1, 0)) > utils::TOLERANCE) { + ang2 = std::arg(matrix(1, 0)); + } else if (std::abs(matrix(0, 1)) > utils::TOLERANCE) { + ang2 = std::arg(matrix(0, 1)); + } + const auto phi = ang1 + ang2 - detArg; + const auto lambda = ang1 - ang2; + return {.theta = theta, .phi = phi, .lambda = lambda, .phase = phase}; +} + +/** + * @brief Z-X-Z Euler angles via `RY(theta) = RZ(pi/2)*RX(theta)*RZ(-pi/2)`. + * + * @param matrix Single-qubit unitary to decompose. + * @return Z-X-Z angles and global phase. + */ +[[nodiscard]] static EulerAngles paramsZXZ(const Matrix2x2& matrix) { + const auto [theta, phi, lambda, phase] = paramsZYZ(matrix); + return {.theta = theta, + .phi = phi + (std::numbers::pi / 2.0), + .lambda = lambda - (std::numbers::pi / 2.0), + .phase = phase}; +} + +/** + * @brief X-Z-X Euler angles (Z-X-Z under H conjugation). + * + * @param matrix Single-qubit unitary to decompose. + * @return X-Z-X angles and global phase. + */ +[[nodiscard]] static EulerAngles paramsXZX(const Matrix2x2& matrix) { + return paramsZXZ(hadamardConjugate(matrix)); +} + +/** + * @brief X-Y-X Euler angles via `H*RY(theta)*H = RY(-theta)`. + * + * @param matrix Single-qubit unitary to decompose. + * @return X-Y-X angles and global phase. + */ +[[nodiscard]] static EulerAngles paramsXYX(const Matrix2x2& matrix) { + // Shift outer angles by pi and fix global phase. + const auto [theta, phi, lambda, phase] = paramsZYZ(hadamardConjugate(matrix)); + return {.theta = theta, + .phi = phi + std::numbers::pi, + .lambda = lambda + std::numbers::pi, + .phase = phase + std::numbers::pi}; +} + +/** + * @brief `U`-basis angles (Z-Y-Z angles with a `U`-vs-`RZ·RY·RZ` phase fix). + * + * @param matrix Single-qubit unitary to decompose. + * @return `U`-gate angles and global phase. + */ +[[nodiscard]] static EulerAngles paramsU(const Matrix2x2& matrix) { + // `U` differs from RZ(phi)*RY(theta)*RZ(lambda) by a global phase of + // -(phi + lambda)/2. + const auto [theta, phi, lambda, phase] = paramsZYZ(matrix); + return {.theta = theta, + .phi = phi, + .lambda = lambda, + .phase = phase - (0.5 * (phi + lambda))}; +} + +/** + * @brief Extracts `(theta, phi, lambda, phase)` for all Euler bases. + * + * @param matrix The single-qubit unitary to decompose. + * @param basis The target Euler basis. + * @return The extracted Euler angles and global phase. + */ +[[nodiscard]] static EulerAngles anglesFromUnitary(const Matrix2x2& matrix, + const EulerBasis basis) { + switch (basis) { + case EulerBasis::ZYZ: + case EulerBasis::ZSXX: + return paramsZYZ(matrix); + case EulerBasis::ZXZ: + return paramsZXZ(matrix); + case EulerBasis::XZX: + return paramsXZX(matrix); + case EulerBasis::XYX: + return paramsXYX(matrix); + case EulerBasis::U: + return paramsU(matrix); + default: + llvm_unreachable("invalid Euler basis"); + } +} + +//===----------------------------------------------------------------------===// +// Euler synthesis (plan + emit) +//===----------------------------------------------------------------------===// + +namespace { + +/** + * @brief One gate in a planned single-qubit synthesis sequence. + * + * `RZ`/`RY`/`RX` use @p theta as the rotation angle; `U` uses all three angles. + */ +struct SynthesisStep { + enum class Kind : std::uint8_t { RZ, RY, RX, SX, X, U }; + + Kind kind = Kind::RZ; + double theta = 0.0; + double phi = 0.0; + double lambda = 0.0; +}; + +/** @brief Planned single-qubit Euler synthesis (gate list + optional `gphase`). + */ +struct Unitary1QEulerPlan { + SmallVector steps; + double phase = 0.0; + + /// @brief Number of native gates in the planned sequence (excludes `gphase`). + [[nodiscard]] std::size_t gateCount() const { return steps.size(); } + + /** + * @brief Appends a rotation step for non-negligible angles. + * + * @param kind The rotation axis (RZ/RY/RX) + * @param angle The rotation angle in radians. + */ + void appendRotation(const SynthesisStep::Kind kind, const double angle) { + if (!isNearZeroRotationAngle(angle)) { + steps.emplace_back(kind, angle); + } + } + + /** + * @brief Appends the decomposition for @p basis based on @p angles. + * + * @param angles The angles to use for the decomposition. + * @param basis The basis to use for the decomposition. + */ + void appendDecomposition(const EulerAngles& angles, const EulerBasis basis) { + if (isNearZeroRotationAngle(angles.theta) && + isNearZeroRotationAngle(angles.phi) && + isNearZeroRotationAngle(angles.lambda)) { + phase = angles.phase; + return; + } + + if (isNearZeroRotationAngle(angles.theta)) { + switch (basis) { + case EulerBasis::ZYZ: + case EulerBasis::ZXZ: + case EulerBasis::ZSXX: + appendRotation(SynthesisStep::Kind::RZ, angles.phi + angles.lambda); + break; + + case EulerBasis::XZX: + case EulerBasis::XYX: + appendRotation(SynthesisStep::Kind::RX, angles.phi + angles.lambda); + break; + case EulerBasis::U: + steps.emplace_back(SynthesisStep::Kind::U, 0.0, angles.phi, + angles.lambda); + break; + } + phase = angles.phase; + return; + } + + switch (basis) { + case EulerBasis::ZYZ: + appendRotation(SynthesisStep::Kind::RZ, angles.lambda); + steps.emplace_back(SynthesisStep::Kind::RY, angles.theta); + appendRotation(SynthesisStep::Kind::RZ, angles.phi); + phase = angles.phase; + break; + case EulerBasis::ZXZ: + appendRotation(SynthesisStep::Kind::RZ, angles.lambda); + steps.emplace_back(SynthesisStep::Kind::RX, angles.theta); + appendRotation(SynthesisStep::Kind::RZ, angles.phi); + phase = angles.phase; + break; + case EulerBasis::XZX: + appendRotation(SynthesisStep::Kind::RX, angles.lambda); + steps.emplace_back(SynthesisStep::Kind::RZ, angles.theta); + appendRotation(SynthesisStep::Kind::RX, angles.phi); + phase = angles.phase; + break; + case EulerBasis::XYX: + appendRotation(SynthesisStep::Kind::RX, angles.lambda); + steps.emplace_back(SynthesisStep::Kind::RY, angles.theta); + appendRotation(SynthesisStep::Kind::RX, angles.phi); + phase = angles.phase; + break; + case EulerBasis::U: + steps.emplace_back(SynthesisStep::Kind::U, angles.theta, angles.phi, + angles.lambda); + phase = angles.phase; + break; + case EulerBasis::ZSXX: { + constexpr double pi = std::numbers::pi; + constexpr double halfPi = std::numbers::pi / 2.0; + constexpr double quarterPi = std::numbers::pi / 4.0; + + if (isNearZeroRotationAngle(angles.theta - halfPi)) { + appendRotation(SynthesisStep::Kind::RZ, angles.lambda - halfPi); + steps.emplace_back(SynthesisStep::Kind::SX); + appendRotation(SynthesisStep::Kind::RZ, angles.phi + halfPi); + phase = angles.phase - quarterPi; + return; + } + + appendRotation(SynthesisStep::Kind::RZ, angles.lambda); + if (isNearZeroRotationAngle(angles.theta - pi)) { + steps.emplace_back(SynthesisStep::Kind::X); + phase = angles.phase - halfPi; + } else { + steps.emplace_back(SynthesisStep::Kind::SX); + appendRotation(SynthesisStep::Kind::RZ, angles.theta + pi); + steps.emplace_back(SynthesisStep::Kind::SX); + phase = angles.phase + halfPi; + } + appendRotation(SynthesisStep::Kind::RZ, angles.phi + pi); + break; + } + } + } +}; +} // namespace + +/** + * @brief Builds a gate plan for @p targetMatrix in @p basis without emitting + * IR. + * + * @param targetMatrix Single-qubit unitary to synthesize. + * @param basis Native gate basis. + * @return Planned gate sequence and optional global phase. + */ +[[nodiscard]] static Unitary1QEulerPlan +planUnitary1QEuler(const Matrix2x2& targetMatrix, const EulerBasis basis) { + Unitary1QEulerPlan plan; + if (targetMatrix.isApprox(Matrix2x2::identity())) { + return plan; + } + + const EulerAngles angles = anglesFromUnitary(targetMatrix, basis); + plan.appendDecomposition(angles, basis); + return plan; +} + +/** + * @brief Emits the gates described by @p plan and returns the output qubit. + * + * @param builder Builder for the emitted operations. + * @param loc Location for the emitted operations. + * @param qubit Input qubit value. + * @param plan Precomputed synthesis plan. + * @return Qubit value after all planned gates (and `gphase` when needed). + */ +[[nodiscard]] static Value +emitUnitary1QEulerPlan(OpBuilder& builder, Location loc, Value qubit, + const Unitary1QEulerPlan& plan) { + for (const auto& [kind, theta, phi, lambda] : plan.steps) { + switch (kind) { + case SynthesisStep::Kind::RZ: + qubit = RZOp::create(builder, loc, qubit, theta).getQubitOut(); + break; + case SynthesisStep::Kind::RY: + qubit = RYOp::create(builder, loc, qubit, theta).getQubitOut(); + break; + case SynthesisStep::Kind::RX: + qubit = RXOp::create(builder, loc, qubit, theta).getQubitOut(); + break; + case SynthesisStep::Kind::SX: + qubit = SXOp::create(builder, loc, qubit).getQubitOut(); + break; + case SynthesisStep::Kind::X: + qubit = XOp::create(builder, loc, qubit).getQubitOut(); + break; + case SynthesisStep::Kind::U: + qubit = + UOp::create(builder, loc, qubit, theta, phi, lambda).getQubitOut(); + break; + } + } + emitGPhaseIfNeeded(builder, loc, plan.phase); + return qubit; +} + +std::optional parseEulerBasis(StringRef basis) { + return StringSwitch>(basis.lower()) + .Case("zyz", EulerBasis::ZYZ) + .Case("zxz", EulerBasis::ZXZ) + .Case("xzx", EulerBasis::XZX) + .Case("xyx", EulerBasis::XYX) + .Case("u", EulerBasis::U) + .Case("zsxx", EulerBasis::ZSXX) + .Default(std::nullopt); +} + +std::optional +synthesizeUnitary1QEuler(OpBuilder& builder, Location loc, Value qubit, + const Matrix2x2& composed, const std::size_t runSize, + const bool hasNonBasisGate, const EulerBasis basis) { + const Unitary1QEulerPlan plan = planUnitary1QEuler(composed, basis); + if (!hasNonBasisGate && runSize <= plan.gateCount()) { + return std::nullopt; + } + return emitUnitary1QEulerPlan(builder, loc, qubit, plan); +} + +} // namespace mlir::qco::decomposition diff --git a/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp b/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp new file mode 100644 index 0000000000..9a91146384 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Dialect/QCO/Utils/Matrix.h" +#include "mlir/Dialect/QCO/Utils/WireIterator.h" + +#include +#include // IWYU pragma: keep (Passes.h.inc) +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_FUSESINGLEQUBITUNITARYRUNS +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +namespace { + +/** Composed unitary and metadata for a fusable run. */ +struct FusableRunScan { + Matrix2x2 composed = Matrix2x2::identity(); + std::size_t gateCount = 0; + bool hasNonBasisGate = false; + UnitaryOpInterface tail; +}; + +} // namespace + +/** + * @brief Whether `gate` can take part in a fusable single-qubit run. + */ +static bool isRunMember(UnitaryOpInterface gate) { + if (!gate || !gate.isSingleQubit() || isa(gate.getOperation())) { + return false; + } + return gate.hasCompileTimeKnownUnitaryMatrix(); +} + +/** + * @brief Whether `op` is a gate that Euler synthesis emits for `basis`. + * + * @param op The operation to classify. + * @param basis The target Euler basis. + * @return Whether `op` is in the gate set for `basis`. + */ +static bool isTargetBasisGate(Operation* op, + const decomposition::EulerBasis basis) { + using decomposition::EulerBasis; + return TypeSwitch(op) + .Case([&](auto) { + return basis == EulerBasis::ZYZ || basis == EulerBasis::ZXZ || + basis == EulerBasis::XZX || basis == EulerBasis::ZSXX; + }) + .Case([&](auto) { + return basis == EulerBasis::ZYZ || basis == EulerBasis::XYX; + }) + .Case([&](auto) { + return basis == EulerBasis::ZXZ || basis == EulerBasis::XZX || + basis == EulerBasis::XYX; + }) + .Case([&](auto) { return basis == EulerBasis::U; }) + .Case([&](auto) { return basis == EulerBasis::ZSXX; }) + .Default([](auto) { return false; }); +} + +/** + * @brief Walks the wire from @p head, composing the run's matrix and metadata. + * + * @param head First gate of the run. + * @param basis Target Euler basis. + * @return Composed matrix, gate count, and run tail. + */ +static FusableRunScan scanFusableRun(UnitaryOpInterface head, + const decomposition::EulerBasis basis) { + FusableRunScan scan; + for (auto* op : WireRange(head.getOutputTarget(0))) { + auto member = dyn_cast_or_null(op); + if (!member || !isRunMember(member)) { + break; + } + const auto matrix = member.getUnitaryMatrix(); + assert(matrix && "run member must have a compile-time 2x2 matrix"); + scan.composed.premultiplyBy(*matrix); + scan.hasNonBasisGate |= !isTargetBasisGate(op, basis); + scan.tail = member; + ++scan.gateCount; + } + return scan; +} + +/** + * @brief Erases a contiguous run from @p tail back to @p head. + * + * @param rewriter The pattern rewriter. + * @param head First gate of the run. + * @param tail Last gate of the run. + */ +static void eraseFusableRun(PatternRewriter& rewriter, UnitaryOpInterface head, + UnitaryOpInterface tail) { + // Tail-first: each erased op is dead once its successor is gone. + auto it = WireIterator(tail.getOutputTarget(0)); + auto* target = head.getOperation(); + while (*it != target) { + auto* current = *it; + --it; + rewriter.eraseOp(current); + } + rewriter.eraseOp(target); +} + +namespace { + +/** + * @brief Fuses maximal single-qubit unitary runs via Euler resynthesis. + */ +struct FuseSingleQubitUnitaryRunsPattern final + : OpInterfaceRewritePattern { + FuseSingleQubitUnitaryRunsPattern(MLIRContext* context, + const decomposition::EulerBasis basis) + : OpInterfaceRewritePattern(context), basis(basis) {} + + decomposition::EulerBasis basis; + + /** + * @brief Whether `op` starts a run. + * + * @param op The candidate run head. + * @return Whether `op` anchors a maximal fusable run. + */ + static bool isRunStart(UnitaryOpInterface op) { + return isRunMember(op) && !isRunMember(dyn_cast_or_null( + op.getInputTarget(0).getDefiningOp())); + } + + /** + * @brief Fuses the run anchored at `op` when beneficial. + * + * Fuses if the run contains a non-basis gate or Euler resynthesis would + * shorten it (@ref synthesizeUnitary1QEuler). + * + * @param op The matched unitary operation. + * @param rewriter The pattern rewriter. + * @return `success()` if a run was fused, `failure()` otherwise. + */ + LogicalResult matchAndRewrite(UnitaryOpInterface op, + PatternRewriter& rewriter) const override { + if (!isRunStart(op)) { + return failure(); + } + + FusableRunScan run = scanFusableRun(op, basis); + const auto qubitOut = decomposition::synthesizeUnitary1QEuler( + rewriter, op.getLoc(), op.getInputTarget(0), run.composed, + run.gateCount, run.hasNonBasisGate, basis); + if (!qubitOut) { + return failure(); + } + + rewriter.replaceAllUsesWith(run.tail.getOutputTarget(0), *qubitOut); + eraseFusableRun(rewriter, op, run.tail); + return success(); + } +}; + +/** + * @brief Pass that fuses single-qubit unitary runs via Euler resynthesis. + */ +struct FuseSingleQubitUnitaryRunsPass final + : impl::FuseSingleQubitUnitaryRunsBase { + using Base::Base; + + explicit FuseSingleQubitUnitaryRunsPass( + FuseSingleQubitUnitaryRunsOptions options) + : Base(std::move(options)) {} + +protected: + void runOnOperation() override { + auto module = getOperation(); + + const auto parsed = decomposition::parseEulerBasis(basis); + if (!parsed) { + module.emitError() << "Invalid Euler basis '" << basis + << "'. Expected one of: zyz, zxz, xzx, xyx, u, zsxx."; + signalPassFailure(); + return; + } + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext(), + *parsed); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco diff --git a/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt b/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt index 6e0d1eed9c..c671a9f564 100644 --- a/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt @@ -27,7 +27,8 @@ add_mlir_dialect_library( MLIRQCOInterfacesIncGen LINK_LIBS PUBLIC - MLIRQCODialect) + MLIRQCODialect + MLIRSCFDialect) mqt_mlir_target_use_project_options(MLIRQCOUtils) diff --git a/mlir/lib/Dialect/QCO/Utils/Matrix.cpp b/mlir/lib/Dialect/QCO/Utils/Matrix.cpp index 01d1b35ea4..8df45840a3 100644 --- a/mlir/lib/Dialect/QCO/Utils/Matrix.cpp +++ b/mlir/lib/Dialect/QCO/Utils/Matrix.cpp @@ -55,7 +55,7 @@ static void assignFixedImpl(std::int64_t& dim, SmallVector& data, template [[nodiscard]] static bool -isApproxFixedImpl(const std::int64_t dim, ArrayRef data, +isApproxFixedImpl(const std::int64_t dim, const SmallVector& data, const std::array& other, const double tol) { if (std::cmp_not_equal(dim, Dim)) { return false; @@ -63,6 +63,52 @@ isApproxFixedImpl(const std::int64_t dim, ArrayRef data, return entriesAreApprox(data, other, tol); } +template +[[nodiscard]] static bool +assignFromDynamicImpl(const DynamicMatrix& src, + std::array& dst) { + if (src.rows() != static_cast(Dim) || + src.cols() != static_cast(Dim)) { + return false; + } + for (std::size_t row = 0; row < Dim; ++row) { + for (std::size_t col = 0; col < Dim; ++col) { + dst[(row * Dim) + col] = + src(static_cast(row), static_cast(col)); + } + } + return true; +} + +/// Writes the row-major product `lhs * rhs` into @p out (2x2, fully unrolled). +static void +multiply2x2(const std::array& lhs, + const std::array& rhs, + std::array& out) { + out[0] = lhs[0] * rhs[0] + lhs[1] * rhs[2]; + out[1] = lhs[0] * rhs[1] + lhs[1] * rhs[3]; + out[2] = lhs[2] * rhs[0] + lhs[3] * rhs[2]; + out[3] = lhs[2] * rhs[1] + lhs[3] * rhs[3]; +} + +/// Writes the row-major product `lhs * rhs` into @p out (4x4, unrolled rows). +static void +multiply4x4(const std::array& lhs, + const std::array& rhs, + std::array& out) { + for (std::size_t row = 0; row < Matrix4x4::K_ROWS; ++row) { + const std::size_t rowBase = row * Matrix4x4::K_COLS; + const Complex& a0 = lhs[rowBase + 0]; + const Complex& a1 = lhs[rowBase + 1]; + const Complex& a2 = lhs[rowBase + 2]; + const Complex& a3 = lhs[rowBase + 3]; + out[rowBase + 0] = a0 * rhs[0] + a1 * rhs[4] + a2 * rhs[8] + a3 * rhs[12]; + out[rowBase + 1] = a0 * rhs[1] + a1 * rhs[5] + a2 * rhs[9] + a3 * rhs[13]; + out[rowBase + 2] = a0 * rhs[2] + a1 * rhs[6] + a2 * rhs[10] + a3 * rhs[14]; + out[rowBase + 3] = a0 * rhs[3] + a1 * rhs[7] + a2 * rhs[11] + a3 * rhs[15]; + } +} + /// Returns @p dim as `size_t` after asserting it is non-negative and squarable. [[nodiscard]] static std::size_t checkedDim(const std::int64_t dim) { assert(dim >= 0 && "DynamicMatrix dimension must be non-negative"); @@ -123,6 +169,25 @@ bool Matrix1x1::isApprox(const Matrix1x1& other, const double tol) const { return std::abs(value - other.value) <= tol; } +bool Matrix1x1::assignFrom(const DynamicMatrix& src) { + if (src.rows() != 1 || src.cols() != 1) { + return false; + } + value = src(0, 0); + return true; +} + +Matrix1x1 Matrix1x1::operator*(const Complex& scalar) const { + return fromElements(value * scalar); +} + +Matrix1x1& Matrix1x1::operator*=(const Complex& scalar) { + value *= scalar; + return *this; +} + +Matrix1x1 Matrix1x1::adjoint() const { return fromElements(std::conj(value)); } + Matrix2x2 Matrix2x2::fromElements(const Complex& m00, const Complex& m01, const Complex& m10, const Complex& m11) { return {{m00, m01, m10, m11}}; @@ -138,10 +203,27 @@ Complex Matrix2x2::operator()(const std::size_t row, } Matrix2x2 Matrix2x2::operator*(const Matrix2x2& rhs) const { - return fromElements(data[0] * rhs.data[0] + data[1] * rhs.data[2], - data[0] * rhs.data[1] + data[1] * rhs.data[3], - data[2] * rhs.data[0] + data[3] * rhs.data[2], - data[2] * rhs.data[1] + data[3] * rhs.data[3]); + Matrix2x2 out{}; + multiply2x2(data, rhs.data, out.data); + return out; +} + +void Matrix2x2::premultiplyBy(const Matrix2x2& lhs) { + const std::array rhs = data; + multiply2x2(lhs.data, rhs, data); +} + +Matrix2x2 Matrix2x2::operator*(const Complex& scalar) const { + Matrix2x2 out = *this; + out *= scalar; + return out; +} + +Matrix2x2& Matrix2x2::operator*=(const Complex& scalar) { + for (Complex& entry : data) { + entry *= scalar; + } + return *this; } Matrix2x2 Matrix2x2::adjoint() const { @@ -159,6 +241,10 @@ bool Matrix2x2::isApprox(const Matrix2x2& other, const double tol) const { return entriesAreApprox(data, other.data, tol); } +bool Matrix2x2::assignFrom(const DynamicMatrix& src) { + return assignFromDynamicImpl(src, data); +} + Matrix4x4 Matrix4x4::fromElements(const Complex& m00, const Complex& m01, const Complex& m02, const Complex& m03, const Complex& m10, const Complex& m11, @@ -182,24 +268,28 @@ Complex Matrix4x4::operator()(const std::size_t row, Matrix4x4 Matrix4x4::operator*(const Matrix4x4& rhs) const { Matrix4x4 out{}; - for (std::size_t row = 0; row < K_ROWS; ++row) { - const std::size_t rowBase = row * K_COLS; - const Complex& a0 = data[rowBase + 0]; - const Complex& a1 = data[rowBase + 1]; - const Complex& a2 = data[rowBase + 2]; - const Complex& a3 = data[rowBase + 3]; - out.data[rowBase + 0] = a0 * rhs.data[0] + a1 * rhs.data[4] + - a2 * rhs.data[8] + a3 * rhs.data[12]; - out.data[rowBase + 1] = a0 * rhs.data[1] + a1 * rhs.data[5] + - a2 * rhs.data[9] + a3 * rhs.data[13]; - out.data[rowBase + 2] = a0 * rhs.data[2] + a1 * rhs.data[6] + - a2 * rhs.data[10] + a3 * rhs.data[14]; - out.data[rowBase + 3] = a0 * rhs.data[3] + a1 * rhs.data[7] + - a2 * rhs.data[11] + a3 * rhs.data[15]; - } + multiply4x4(data, rhs.data, out.data); + return out; +} + +void Matrix4x4::premultiplyBy(const Matrix4x4& lhs) { + const std::array rhs = data; + multiply4x4(lhs.data, rhs, data); +} + +Matrix4x4 Matrix4x4::operator*(const Complex& scalar) const { + Matrix4x4 out = *this; + out *= scalar; return out; } +Matrix4x4& Matrix4x4::operator*=(const Complex& scalar) { + for (Complex& entry : data) { + entry *= scalar; + } + return *this; +} + Matrix4x4 Matrix4x4::adjoint() const { Matrix4x4 out{}; adjointInto(data, out.data, K_ROWS); @@ -231,6 +321,10 @@ bool Matrix4x4::isApprox(const Matrix4x4& other, const double tol) const { return entriesAreApprox(data, other.data, tol); } +bool Matrix4x4::assignFrom(const DynamicMatrix& src) { + return assignFromDynamicImpl(src, data); +} + DynamicMatrix::DynamicMatrix() : impl_(std::make_unique()) {} DynamicMatrix::DynamicMatrix(const std::int64_t dim) @@ -239,6 +333,16 @@ DynamicMatrix::DynamicMatrix(const std::int64_t dim) impl_->data.assign(checkedStorageSize(dim), Complex{}); } +DynamicMatrix::DynamicMatrix(const Matrix2x2& src) + : impl_(std::make_unique()) { + assignFrom(src); +} + +DynamicMatrix::DynamicMatrix(const Matrix4x4& src) + : impl_(std::make_unique()) { + assignFrom(src); +} + DynamicMatrix::DynamicMatrix(const DynamicMatrix& other) : impl_(std::make_unique(*other.impl_)) {} @@ -269,6 +373,10 @@ DynamicMatrix DynamicMatrix::identity(const std::int64_t dim) { return matrix; } +DynamicMatrix DynamicMatrix::fromAdjoint(const Matrix2x2& src) { + return DynamicMatrix(src.adjoint()); +} + Complex& DynamicMatrix::operator()(const std::int64_t row, const std::int64_t col) { return impl_->data[static_cast((row * impl_->dim) + col)]; @@ -321,6 +429,13 @@ void DynamicMatrix::assignFrom(const DynamicMatrix& src) { *impl_ = *src.impl_; } +bool DynamicMatrix::isApprox(const Matrix1x1& other, const double tol) const { + if (impl_->dim != 1) { + return false; + } + return std::abs(impl_->data[0] - other.value) <= tol; +} + bool DynamicMatrix::isApprox(const Matrix2x2& other, const double tol) const { return isApproxFixedImpl( diff --git a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp index 45522cd8bf..bb7bb6932f 100644 --- a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp +++ b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -25,8 +26,10 @@ namespace mlir::qco { Value WireIterator::qubit() const { - // A sink/deallocation/insert doesn't have an OpResult. - if (op_ != nullptr && (isa(op_))) { + // Boundary ops (sink/deallocation/insert/yield) consume the wire via an + // operand and have no OpResult, matching the boundaries in forward/backward. + if (op_ != nullptr && + (isa(op_))) { return nullptr; } return qubit_; @@ -42,8 +45,9 @@ void WireIterator::forward() { assert(qubit_.hasOneUse() && "expected linear typing"); op_ = *(qubit_.user_begin()); - // A sink/insert defines the end of the qubit wire (dynamic and static). - if (isa(op_)) { + // A sink/insert/yield or region entry defines the end of the qubit wire. + if (isa(op_)) { isSentinel_ = true; return; } @@ -70,9 +74,10 @@ void WireIterator::backward() { return; } - // For sinks/deallocations/inserts, qubit_ is an OpOperand. Hence, only get - // the def-op. - if (isa(op_)) { + // For sinks/deallocations/inserts/yields, qubit_ is an OpOperand. Hence, only + // get the def-op. + if (isa(op_)) { op_ = qubit_.getDefiningOp(); return; } diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir_matrix.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir_matrix.cpp index 91c31f9899..e196b6c1b8 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir_matrix.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir_matrix.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -115,6 +116,120 @@ TEST_F(QCOMatrixTest, InverseIswapOpMatrix) { ASSERT_TRUE(matrix->isApprox(expected)); } + +TEST_F(QCOMatrixTest, InverseTwoXOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), inverseTwoX); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + const auto matrix = invOp.getUnitaryMatrix(); + ASSERT_TRUE(matrix); + + DynamicMatrix expected; + expected.assignFrom(Matrix2x2::identity()); + ASSERT_TRUE(matrix->isApprox(expected)); +} + +TEST_F(QCOMatrixTest, InverseXOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), inverseX); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + const auto matrix = invOp.getUnitaryMatrix(); + ASSERT_TRUE(matrix); + + DynamicMatrix expected; + expected.assignFrom(XOp::getUnitaryMatrix()); + ASSERT_TRUE(matrix->isApprox(expected)); +} + +TEST_F(QCOMatrixTest, InverseSxOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), inverseSx); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + const auto matrix = invOp.getUnitaryMatrix(); + ASSERT_TRUE(matrix); + + DynamicMatrix expected; + expected.assignFrom(SXdgOp::getUnitaryMatrix()); + ASSERT_TRUE(matrix->isApprox(expected)); +} + +TEST_F(QCOMatrixTest, InverseGphaseXOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), inverseGphaseX); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + const auto matrix = invOp.getUnitaryMatrix(); + ASSERT_TRUE(matrix); + + const auto composeGlobal = std::polar(1.0, -0.123); + const Matrix2x2 body = XOp::getUnitaryMatrix() * composeGlobal; + + ASSERT_TRUE(matrix->isApprox(DynamicMatrix::fromAdjoint(body))); +} + +TEST_F(QCOMatrixTest, InverseGphaseBarrierOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), inverseGphaseBarrier); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + const auto matrix = invOp.getUnitaryMatrix(); + ASSERT_TRUE(matrix); + + const auto global = std::conj(std::polar(1.0, 0.123)); + DynamicMatrix expected; + expected.assignFrom(Matrix2x2::fromElements(global, 0, 0, global)); + ASSERT_TRUE(matrix->isApprox(expected)); +} + +TEST_F(QCOMatrixTest, InverseTwoBarriersInInvOpMatrix) { + auto moduleOp = + QCOProgramBuilder::build(context.get(), inverseTwoBarriersInInv); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + EXPECT_FALSE(invOp.getUnitaryMatrix()); +} + +TEST_F(QCOMatrixTest, InvTwoOpMatrix) { + auto moduleOp = QCOProgramBuilder::build(context.get(), invTwo); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + EXPECT_FALSE(invOp.getUnitaryMatrix()); +} + +TEST_F(QCOMatrixTest, InverseDynamicRzXOpMatrix) { + constexpr auto mlirCode = R"( + module { + func.func @test(%theta: f64) -> !qco.qubit { + %q_in = qco.alloc : !qco.qubit + %q_out = qco.inv (%q = %q_in) { + %q_1 = qco.rz(%theta) %q : !qco.qubit -> !qco.qubit + %q_2 = qco.x %q_1 : !qco.qubit -> !qco.qubit + qco.yield %q_2 : !qco.qubit + } : {!qco.qubit} -> {!qco.qubit} + return %q_out : !qco.qubit + } + } + )"; + + auto moduleOp = parseSourceString(mlirCode, context.get()); + ASSERT_TRUE(moduleOp); + + auto funcOp = *moduleOp->getBody()->getOps().begin(); + auto invOp = *funcOp.getBody().getOps().begin(); + EXPECT_FALSE(invOp.getUnitaryMatrix()); +} /// @} /// \name QCO/Operations/StandardGates/DcxOp.cpp diff --git a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt index 9f9b03449d..d59780f461 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt @@ -6,5 +6,6 @@ # # Licensed under the MIT License +add_subdirectory(Decomposition) add_subdirectory(Mapping) add_subdirectory(Optimizations) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt new file mode 100644 index 0000000000..f493bb9e4d --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(target_name mqt-core-mlir-unittest-decomposition) +add_executable(${target_name} test_euler_decomposition.cpp) + +target_link_libraries(${target_name} PRIVATE GTest::gtest_main MLIRQCOProgramBuilder + MLIRQCOTransforms) +target_link_libraries(${target_name} PRIVATE MLIRPass MLIRFuncDialect MLIRArithDialect MLIRIR + MLIRSupport MLIRQTensorDialect) + +mqt_mlir_configure_unittest_target(${target_name}) + +gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp new file mode 100644 index 0000000000..fa9ba1d4ea --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp @@ -0,0 +1,1042 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Dialect/QCO/Utils/Matrix.h" +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qco; +using namespace mlir::qco::decomposition; +using enum EulerBasis; + +// File layout: +// 1. Fixtures and parametric test types +// 2. Euler synthesis support + tests +// 3. FuseSingleQubitUnitaryRuns support + tests + +namespace { + +struct TestFixture { + std::unique_ptr context; + + void setUp() { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } + + [[nodiscard]] MLIRContext* ctx() const { return context.get(); } +}; + +struct ZSXXShortcutCase { + std::string_view label; + std::function makeMatrix; + std::size_t expectedRZ; + std::size_t expectedSX; + std::size_t expectedX; +}; + +class ZSXXShortcutTest : public testing::TestWithParam {}; + +struct SynthesizedCircuit { + OwningOpRef mlirModule; + func::FuncOp func; +}; + +class EulerSynthesisExactTest + : public testing::TestWithParam< + std::tuple> {}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Euler synthesis support +//===----------------------------------------------------------------------===// + +[[nodiscard]] static Matrix2x2 rzMatrix(const double theta) { + const auto m00 = std::polar(1.0, -theta / 2.0); + const auto m11 = std::polar(1.0, theta / 2.0); + return Matrix2x2::fromElements(m00, 0, 0, m11); +} + +[[nodiscard]] static Matrix2x2 ryMatrix(const double theta) { + const auto m00 = std::cos(theta / 2.0); + const auto m01 = -std::sin(theta / 2.0); + return Matrix2x2::fromElements(m00, m01, -m01, m00); +} + +[[nodiscard]] static Matrix2x2 randomUnitaryMatrix(std::mt19937& rng) { + std::uniform_real_distribution dist(-std::numbers::pi, std::numbers::pi); + const Matrix2x2 su2 = + rzMatrix(dist(rng)) * ryMatrix(dist(rng)) * rzMatrix(dist(rng)); + const Complex globalPhase = std::polar(1.0, dist(rng)); + return Matrix2x2::fromElements( + globalPhase * su2(0, 0), globalPhase * su2(0, 1), globalPhase * su2(1, 0), + globalPhase * su2(1, 1)); +} + +template +[[nodiscard]] static Matrix2x2 rotationMatrix(MLIRContext* ctx, + const double theta) { + OpBuilder builder(ctx); + auto mlirModule = ModuleOp::create(UnknownLoc::get(ctx)); + builder.setInsertionPointToStart(mlirModule.getBody()); + const Location loc = mlirModule.getLoc(); + Value q = AllocOp::create(builder, loc).getResult(); + auto op = RotationOp::create(builder, loc, q, theta); + const auto matrix = op.getUnitaryMatrix(); + if (!matrix) { + ADD_FAILURE() << "Expected constant unitary matrix"; + return Matrix2x2::identity(); + } + return *matrix; +} + +template static void forEachBasis(Fn fn) { + const std::array bases = {"zyz", "zxz", "xzx", + "xyx", "u", "zsxx"}; + for (const char* basis : bases) { + fn(StringRef{basis}); + } +} +[[nodiscard]] static WalkResult failMissingUnitaryMatrix(Operation* op, + bool& failed) { + ADD_FAILURE() << "Expected constant unitary matrix for op: " + << op->getName().getStringRef().str(); + failed = true; + return WalkResult::interrupt(); +} + +[[nodiscard]] static WalkResult +accumulateConstantSingleQubit(UnitaryOpInterface unitary, Operation* op, + Matrix2x2& acc, bool& failed) { + if (Matrix2x2 matrix; unitary.getUnitaryMatrix2x2(matrix)) { + acc = matrix * acc; + return WalkResult::advance(); + } + return failMissingUnitaryMatrix(op, failed); +} + +static WalkResult visit1QUnitaryOp(Operation* op, Matrix2x2& acc, + std::complex& global, bool& failed) { + if (isa(*op)) { + return WalkResult::advance(); + } + if (auto gphase = dyn_cast(*op)) { + if (auto matrix = gphase.getUnitaryMatrix()) { + global *= (*matrix)(0, 0); + } + return WalkResult::advance(); + } + auto unitary = dyn_cast(*op); + if (!unitary) { + return WalkResult::advance(); + } + if (isa(*op)) { + if (!unitary.isSingleQubit()) { + return WalkResult::skip(); + } + const WalkResult result = + accumulateConstantSingleQubit(unitary, op, acc, failed); + return failed ? result : WalkResult::skip(); + } + if (unitary.isTwoQubit()) { + return WalkResult::advance(); + } + const WalkResult result = + accumulateConstantSingleQubit(unitary, op, acc, failed); + return failed ? result : WalkResult::advance(); +} +template +static Matrix2x2 compute1QUnitaryMatrix(WalkRange& range) { + Matrix2x2 acc = Matrix2x2::identity(); + std::complex global{1.0, 0.0}; + bool failed = false; + + range.template walk( + [&acc, &global, &failed](Operation* op) { + return visit1QUnitaryOp(op, acc, global, failed); + }); + + if (failed) { + return Matrix2x2::fromElements(0, 0, 0, 0); + } + return acc * global; +} +static void expectMatrixPreserved(func::FuncOp funcOp, + const Matrix2x2& original, + StringRef label = {}) { + // Logging of the matrices + auto printMatrix = [](const Matrix2x2& matrix) { + std::ostringstream oss; + oss.precision(4); + oss << std::fixed << "[[" << matrix(0, 0) << ", " << matrix(0, 1) << "],\n" + << " [" << matrix(1, 0) << ", " << matrix(1, 1) << "]]"; + return oss.str(); + }; + const auto printOriginal = printMatrix(original); + const auto actual = compute1QUnitaryMatrix(funcOp.getBody()); + const auto printActual = printMatrix(actual); + EXPECT_TRUE(actual.isApprox(original)) + << "Matrix not preserved for " << label.str() << ":\nOriginal:\n" + << printOriginal << "\nActual:\n" + << printActual; +} +template +[[nodiscard]] static std::size_t countOps(func::FuncOp funcOp) { + std::size_t count = 0; + funcOp.walk([&count](OpTy) { ++count; }); + return count; +} + +[[nodiscard]] static std::size_t countZYZGates(func::FuncOp funcOp) { + return countOps(funcOp) + countOps(funcOp); +} + +[[nodiscard]] static std::size_t countZSXXGates(func::FuncOp funcOp) { + return countOps(funcOp) + countOps(funcOp) + + countOps(funcOp); +} + +[[nodiscard]] static std::size_t countBasisGates(func::FuncOp funcOp, + EulerBasis basis) { + switch (basis) { + case ZYZ: + return countZYZGates(funcOp); + case ZXZ: + return countOps(funcOp) + countOps(funcOp); + case XZX: + return countOps(funcOp) + countOps(funcOp); + case XYX: + return countOps(funcOp) + countOps(funcOp); + case U: + return countOps(funcOp); + case ZSXX: + return countZSXXGates(funcOp); + } + return 0; +} + +[[nodiscard]] static SynthesizedCircuit +synthesizeMatrix(MLIRContext* ctx, const Matrix2x2& matrix, EulerBasis basis) { + OwningOpRef mlirModule = ModuleOp::create(UnknownLoc::get(ctx)); + OpBuilder builder(ctx); + builder.setInsertionPointToStart(mlirModule->getBody()); + + auto qubitTy = QubitType::get(ctx); + auto funcTy = builder.getFunctionType({qubitTy}, {qubitTy}); + const Location loc = mlirModule->getLoc(); + auto func = func::FuncOp::create(builder, loc, "main", funcTy); + auto* entry = func.addEntryBlock(); + + builder.setInsertionPointToStart(entry); + Value q = entry->getArgument(0); + const std::optional qubitOut = + synthesizeUnitary1QEuler(builder, loc, q, matrix, 0, true, basis); + if (!qubitOut) { + llvm::report_fatal_error( + "synthesizeUnitary1QEuler failed during test synthesis"); + } + func::ReturnOp::create(builder, loc, *qubitOut); + return SynthesizedCircuit{.mlirModule = std::move(mlirModule), .func = func}; +} + +[[nodiscard]] static std::size_t expectedGateCount(MLIRContext* ctx, + const Matrix2x2& segment, + EulerBasis basis) { + return countBasisGates(synthesizeMatrix(ctx, segment, basis).func, basis); +} + +static void checkSynthesizedReferenceExtras(MLIRContext* ctx, + func::FuncOp funcOp, + EulerBasis basis, + const Matrix2x2& matrix) { + if (basis == U) { + EXPECT_EQ(countOps(funcOp), expectedGateCount(ctx, matrix, basis)); + } + if (!matrix.isApprox(Matrix2x2::identity())) { + return; + } + if (basis == ZYZ) { + EXPECT_EQ(countZYZGates(funcOp), 0U); + } + if (basis == U) { + EXPECT_EQ(countOps(funcOp), 0U); + } +} + +template +static void expectSynthesizedMatrix(MLIRContext* ctx, const Matrix2x2& matrix, + EulerBasis basis, + ExtraChecksT extraChecks) { + const auto circuit = synthesizeMatrix(ctx, matrix, basis); + ASSERT_TRUE(succeeded(verify(*circuit.mlirModule))); + extraChecks(circuit.func, matrix); + expectMatrixPreserved(circuit.func, matrix, "synthesis"); +} + +//===----------------------------------------------------------------------===// +// Euler synthesis tests +//===----------------------------------------------------------------------===// + +TEST_P(ZSXXShortcutTest, SynthesisMatchesGateCount) { + TestFixture fx; + fx.setUp(); + const auto& testCase = GetParam(); + const Matrix2x2 matrix = testCase.makeMatrix(fx.ctx()); + + expectSynthesizedMatrix( + fx.ctx(), matrix, ZSXX, + [&testCase, &fx](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ(countOps(funcOp), testCase.expectedRZ); + EXPECT_EQ(countOps(funcOp), testCase.expectedSX); + EXPECT_EQ(countOps(funcOp), testCase.expectedX); + EXPECT_EQ(countZSXXGates(funcOp), + expectedGateCount(fx.ctx(), original, ZSXX)); + }); +} + +INSTANTIATE_TEST_SUITE_P( + ZSXXShortcuts, ZSXXShortcutTest, + testing::Values( + ZSXXShortcutCase{ + "Identity", + [](MLIRContext*) -> Matrix2x2 { return Matrix2x2::identity(); }, 0, + 0, 0}, + ZSXXShortcutCase{ + "PauliX", + [](MLIRContext*) -> Matrix2x2 { return XOp::getUnitaryMatrix(); }, + 0, 0, 1}, + ZSXXShortcutCase{"PureZ", + [](MLIRContext*) -> Matrix2x2 { + return rzMatrix(0.3) * rzMatrix(0.7); + }, + 1, 0, 0}, + ZSXXShortcutCase{"ZYZNearZeroTheta", + [](MLIRContext*) -> Matrix2x2 { + constexpr double tol = 0.5 * mlir::utils::TOLERANCE; + return rzMatrix(0.4) * ryMatrix(tol) * rzMatrix(0.3); + }, + 1, 0, 0}, + ZSXXShortcutCase{"RYHalfPi", + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix(ctx, + std::numbers::pi / 2.0); + }, + 2, 1, 0}, + ZSXXShortcutCase{"RYNearHalfPi", + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix( + ctx, (std::numbers::pi / 2.0) + + (0.5 * mlir::utils::TOLERANCE)); + }, + 2, 1, 0}, + ZSXXShortcutCase{"RYNearZero", + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix( + ctx, 0.5 * mlir::utils::TOLERANCE); + }, + 0, 0, 0}, + ZSXXShortcutCase{"RYNearPi", + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix( + ctx, std::numbers::pi - + (0.5 * mlir::utils::TOLERANCE)); + }, + 1, 0, 1}), + [](const testing::TestParamInfo& info) { + return std::string(info.param.label); + }); + +TEST_P(EulerSynthesisExactTest, ReconstructsReferenceMatrices) { + TestFixture fx; + fx.setUp(); + const auto [basis, matrixFn] = GetParam(); + const Matrix2x2 original = matrixFn(fx.ctx()); + expectSynthesizedMatrix( + fx.ctx(), original, basis, + [&fx, basis](func::FuncOp funcOp, const Matrix2x2& matrix) { + checkSynthesizedReferenceExtras(fx.ctx(), funcOp, basis, matrix); + }); +} + +INSTANTIATE_TEST_SUITE_P( + SingleQubitMatrices, EulerSynthesisExactTest, + testing::Combine(testing::Values(ZYZ, ZXZ, XZX, XYX, U, ZSXX), + testing::Values( + [](MLIRContext* /*ctx*/) -> Matrix2x2 { + return Matrix2x2::identity(); + }, + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix(ctx, 2.0); + }, + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix(ctx, + std::numbers::pi / 2.0); + }, + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix(ctx, 0.5); + }, + [](MLIRContext* ctx) -> Matrix2x2 { + return rotationMatrix(ctx, 3.14); + }, + [](MLIRContext* /*ctx*/) -> Matrix2x2 { + return HOp::getUnitaryMatrix(); + }))); + +TEST(EulerSynthesisTest, RandomReconstructionAllBases) { + TestFixture fx; + fx.setUp(); + std::mt19937 rng{12345678UL}; + + for (int i = 0; i < 200; ++i) { + const auto original = randomUnitaryMatrix(rng); + forEachBasis([&fx, &original](StringRef basisStr) { + const auto parsed = parseEulerBasis(basisStr); + ASSERT_TRUE(parsed) << "basis=" << basisStr.str(); + const auto circuit = synthesizeMatrix(fx.ctx(), original, *parsed); + ASSERT_TRUE(succeeded(verify(*circuit.mlirModule))) + << "basis=" << basisStr.str(); + expectMatrixPreserved(circuit.func, original, basisStr); + }); + } +} + +//===----------------------------------------------------------------------===// +// FuseSingleQubitUnitaryRuns support +//===----------------------------------------------------------------------===// + +[[nodiscard]] static bool isAllowedBasisGate(const Operation& op, + EulerBasis basis) { + switch (basis) { + case ZYZ: + return isa(op); + case ZXZ: + return isa(op); + case XZX: + return isa(op); + case XYX: + return isa(op); + case U: + return isa(op); + case ZSXX: + return isa(op); + } + return false; +} + +template [[nodiscard]] static bool inParent(Operation* op) { + return op != nullptr && op->getParentOfType() != nullptr; +} + +static WalkResult visitBasisGateOp(Operation* op, StringRef basis, + EulerBasis parsedBasis) { + if (isa(*op)) { + return WalkResult::advance(); + } + if (auto unitary = dyn_cast(*op)) { + if (unitary.isTwoQubit() || isa(*op)) { + return unitary.isTwoQubit() ? WalkResult::advance() : WalkResult::skip(); + } + if (Matrix2x2 matrix; unitary.getUnitaryMatrix2x2(matrix)) { + EXPECT_TRUE(isAllowedBasisGate(*op, parsedBasis) || isa(*op)) + << "basis=" << basis.str() + << " unexpected gate: " << op->getName().getStringRef().str(); + return WalkResult::advance(); + } + ADD_FAILURE() << "basis=" << basis.str() << " missing constant matrix for: " + << op->getName().getStringRef().str(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); +} + +static void skipBeforeFuse(func::FuncOp /*funcOp*/, + const Matrix2x2& /*original*/) { + // Pre-fuse checks are not required for this scenario. +} + +template +[[nodiscard]] static Matrix2x2 matrixInParent(func::FuncOp funcOp) { + auto parents = funcOp.getOps(); + if (parents.begin() == parents.end()) { + ADD_FAILURE() << "Expected parent op in function"; + return Matrix2x2::fromElements(0, 0, 0, 0); + } + return compute1QUnitaryMatrix((*parents.begin()).getRegion()); +} + +static void expectBasisGatesOnly(func::FuncOp funcOp, StringRef basis) { + const auto parsed = parseEulerBasis(basis); + ASSERT_TRUE(parsed) << basis.str(); + + funcOp.walk( + [basis, parsedBasis = *parsed](Operation* op) { + return visitBasisGateOp(op, basis, parsedBasis); + }); +} + +static void expectFusePreserved(func::FuncOp funcOp, const Matrix2x2& original, + StringRef basis) { + expectMatrixPreserved(funcOp, original, basis); + expectBasisGatesOnly(funcOp, basis); +} +[[nodiscard]] static Matrix2x2 splitFixtureHTSegmentMatrix() { + return TOp::getUnitaryMatrix() * HOp::getUnitaryMatrix(); +} + +[[nodiscard]] static Matrix2x2 splitFixtureRZSXSegmentMatrix() { + return SXOp::getUnitaryMatrix() * rzMatrix(0.321); +} + +[[nodiscard]] static Matrix2x2 overlongZSXXPureZRunMatrix() { + return SXOp::getUnitaryMatrix() * rzMatrix(std::numbers::pi) * + SXOp::getUnitaryMatrix(); +} +template +[[nodiscard]] static std::size_t countInParent(func::FuncOp funcOp) { + std::size_t count = 0; + funcOp.walk([&count](OpTy op) { + if (inParent(op.getOperation())) { + ++count; + } + }); + return count; +} +static void expectSplitFixtureSegments(func::FuncOp funcOp, StringRef basis, + MLIRContext* ctx) { + const auto parsed = parseEulerBasis(basis); + ASSERT_TRUE(parsed) << basis.str(); + const std::size_t ht = + expectedGateCount(ctx, splitFixtureHTSegmentMatrix(), *parsed); + const std::size_t rzsx = + expectedGateCount(ctx, splitFixtureRZSXSegmentMatrix(), *parsed); + + std::size_t outside = 0; + std::size_t inside = 0; + funcOp.walk([&outside, &inside](Operation* op) { + if (isa(*op)) { + return; + } + auto unitary = dyn_cast(op); + if (Matrix2x2 matrix; unitary && unitary.isSingleQubit() && + unitary.getUnitaryMatrix2x2(matrix)) { + if (inParent(op)) { + ++inside; + } else { + ++outside; + } + } + }); + EXPECT_EQ(outside, ht) << "basis=" << basis.str(); + EXPECT_EQ(inside, rzsx) << "basis=" << basis.str(); +} + +template +static void expectSplitFixtureSegments(func::FuncOp funcOp, StringRef basis, + MLIRContext* ctx, + BoundaryPred isBoundary) { + const auto parsed = parseEulerBasis(basis); + ASSERT_TRUE(parsed) << basis.str(); + const std::size_t ht = + expectedGateCount(ctx, splitFixtureHTSegmentMatrix(), *parsed); + const std::size_t rzsx = + expectedGateCount(ctx, splitFixtureRZSXSegmentMatrix(), *parsed); + + std::size_t before = 0; + std::size_t after = 0; + bool seenBoundary = false; + for (Operation& op : funcOp.getBody().front().without_terminator()) { + if (!seenBoundary && isBoundary(op)) { + seenBoundary = true; + continue; + } + if (isa(op)) { + continue; + } + auto unitary = dyn_cast(op); + if (Matrix2x2 matrix; unitary && unitary.isSingleQubit() && + unitary.getUnitaryMatrix2x2(matrix)) { + if (seenBoundary) { + ++after; + } else { + ++before; + } + } + } + EXPECT_EQ(before, ht) << "basis=" << basis.str(); + EXPECT_EQ(after, rzsx) << "basis=" << basis.str(); +} + +static LogicalResult runFuse(ModuleOp mlirModule, StringRef basis) { + PassManager pm(mlirModule.getContext()); + qco::FuseSingleQubitUnitaryRunsOptions opts; + opts.basis = basis.str(); + pm.addPass(qco::createFuseSingleQubitUnitaryRuns(opts)); + return pm.run(mlirModule); +} + +template +static void runFuseOnProgram(MLIRContext* ctx, ProgramT program, + StringRef basis, BeforeT beforeFuse, + AfterT afterFuse) { + auto owned = QCOProgramBuilder::build(ctx, program); + ASSERT_TRUE(owned); + ModuleOp mlirModule = *owned; + ASSERT_TRUE(succeeded(verify(mlirModule))); + + auto funcOp = mlirModule.lookupSymbol("main"); + ASSERT_TRUE(funcOp); + const Matrix2x2 original = compute1QUnitaryMatrix(funcOp); + beforeFuse(funcOp, original); + + ASSERT_TRUE(succeeded(runFuse(mlirModule, basis))); + ASSERT_TRUE(succeeded(verify(mlirModule))); + + funcOp = mlirModule.lookupSymbol("main"); + ASSERT_TRUE(funcOp); + afterFuse(funcOp, original); +} + +template +static void runFuseForAllBases(MLIRContext* ctx, ProgramT program, + ChecksT checksAfter) { + forEachBasis([&ctx, program, &checksAfter](StringRef basis) { + runFuseOnProgram( + ctx, program, basis, skipBeforeFuse, + [basis, &checksAfter](func::FuncOp funcOp, const Matrix2x2& original) { + checksAfter(funcOp, basis, original); + }); + }); +} + +template +static void runFuseInParent(MLIRContext* ctx, ProgramT program, + BeforeT checkBefore, AfterT checkAfter) { + Matrix2x2 bodyBefore; + runFuseOnProgram( + ctx, program, "u", + [&checkBefore, &bodyBefore](func::FuncOp funcOp, const Matrix2x2&) { + checkBefore(funcOp); + bodyBefore = matrixInParent(funcOp); + }, + [&checkAfter, &bodyBefore](func::FuncOp funcOp, const Matrix2x2&) { + checkAfter(funcOp); + EXPECT_TRUE(matrixInParent(funcOp).isApprox( + bodyBefore, MATRIX_TOLERANCE)); + }); +} + +// --- Fuse program fixtures --- // + +static void singleQubitRunWithSingleQubitGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + q[0] = b.rz(0.123, q[0]); + q[0] = b.inv({q[0]}, [&b](ValueRange targets) -> SmallVector { + return {b.sx(targets[0])}; + })[0]; + q[0] = b.ry(-0.456, q[0]); +} + +static void singleQubitRunsSplitByTwoQGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + std::tie(q[0], q[1]) = b.swap(q[0], q[1]); + q[0] = b.rz(0.321, q[0]); + q[0] = b.sx(q[0]); +} + +static void singleQubitRunsSplitByBarrier(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + q[0] = b.barrier({q[0]})[0]; + q[0] = b.rz(0.321, q[0]); + q[0] = b.sx(q[0]); +} + +static void singleNonBasisGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); +} + +static void singlePauliX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.x(q[0]); +} + +static void canonicalZYZRun(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.rz(0.3, q[0]); + q[0] = b.ry(0.5, q[0]); + q[0] = b.rz(0.7, q[0]); +} + +static void overlongZYZRun(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.rz(0.3, q[0]); + q[0] = b.ry(0.5, q[0]); + q[0] = b.rz(0.7, q[0]); + q[0] = b.ry(0.9, q[0]); + q[0] = b.rz(1.1, q[0]); + q[0] = b.ry(1.3, q[0]); +} + +static void overlongZSXXMixedPureZRun(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.sx(q[0]); + q[0] = b.rz(std::numbers::pi, q[0]); + q[0] = b.sx(q[0]); +} + +static void singleQubitRunInScfFor(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.scfFor(0, 1, 1, ValueRange{q[0]}, [&b](Value, ValueRange iterArgs) { + Value wire = iterArgs[0]; + wire = b.h(wire); + wire = b.t(wire); + wire = b.rz(0.123, wire); + return SmallVector{wire}; + }); +} + +static void xInverseTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.x(q[0]); + q[0] = b.inv({q[0]}, [&b](ValueRange targets) { + Value wire = b.x(targets[0]); + wire = b.x(wire); + return SmallVector{wire}; + })[0]; + q[0] = b.x(q[0]); +} + +static void inverseMultiQubitBodySingleQubitRun(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + auto outs = + b.inv({q[0], q[1]}, [&b](ValueRange targets) -> SmallVector { + Value wire = b.h(targets[0]); + wire = b.t(wire); + return {wire, targets[1]}; + }); + q[0] = outs[0]; + q[1] = outs[1]; +} + +static void controlledInverseHT(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.ctrl(q[0], q[1], [&b](ValueRange targets) { + auto wire = b.inv({targets[0]}, [&b](ValueRange innerTargets) { + auto inner = b.h(innerTargets[0]); + inner = b.t(inner); + return SmallVector{inner}; + })[0]; + return SmallVector{wire}; + }); +} + +static void controlledH(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.ctrl(q[0], q[1], + [&b](ValueRange targets) { return SmallVector{b.h(targets[0])}; }); +} + +static void singleQubitRunsSplitByScfFor(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + b.scfFor(0, 1, 1, ValueRange{q[0]}, [&b](Value, ValueRange iterArgs) { + Value wire = iterArgs[0]; + wire = b.rz(0.321, wire); + wire = b.sx(wire); + return SmallVector{wire}; + }); +} + +//===----------------------------------------------------------------------===// +// FuseSingleQubitUnitaryRuns tests +//===----------------------------------------------------------------------===// + +TEST(FuseSingleQubitUnitaryRunsTest, InvalidBasisFailsPass) { + TestFixture fx; + fx.setUp(); + auto owned = + QCOProgramBuilder::build(fx.ctx(), &singleQubitRunWithSingleQubitGate); + ASSERT_TRUE(owned); + EXPECT_TRUE(failed(runFuse(*owned, "not-a-basis"))); +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesProgramsAllBases) { + TestFixture fx; + fx.setUp(); + + struct Case { + void (*program)(QCOProgramBuilder&); + void (*extra)(func::FuncOp, StringRef); + }; + const std::array cases = {{ + {.program = &singleQubitRunWithSingleQubitGate, + .extra = + [](func::FuncOp funcOp, StringRef basis) { + EXPECT_EQ(countOps(funcOp), 0U) << basis.str(); + }}, + {.program = &singleNonBasisGate, + .extra = + [](func::FuncOp funcOp, StringRef basis) { + EXPECT_EQ(countOps(funcOp), 0U) << basis.str(); + }}, + }}; + + for (const Case& testCase : cases) { + runFuseForAllBases(fx.ctx(), testCase.program, + [&testCase](func::FuncOp funcOp, StringRef basis, + const Matrix2x2& original) { + testCase.extra(funcOp, basis); + expectFusePreserved(funcOp, original, basis); + }); + } +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesOverlongInBasisRun) { + TestFixture fx; + fx.setUp(); + runFuseOnProgram( + fx.ctx(), &overlongZYZRun, "zyz", + [](func::FuncOp funcOp, const Matrix2x2&) { + ASSERT_EQ(countZYZGates(funcOp), 6U); + }, + [&fx](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ(countZYZGates(funcOp), + expectedGateCount(fx.ctx(), original, ZYZ)); + expectFusePreserved(funcOp, original, "zyz"); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, DoesNotFuseCanonicalInBasisRun) { + TestFixture fx; + fx.setUp(); + + runFuseOnProgram(fx.ctx(), &singlePauliX, "zsxx", skipBeforeFuse, + [](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ(countOps(funcOp), 1U); + expectFusePreserved(funcOp, original, "zsxx"); + }); + + runFuseOnProgram(fx.ctx(), &canonicalZYZRun, "zyz", skipBeforeFuse, + [](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ(countZYZGates(funcOp), 3U); + expectFusePreserved(funcOp, original, "zyz"); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, + FusesOverlongZSXXMixedRunComposingToPureZ) { + TestFixture fx; + fx.setUp(); + runFuseOnProgram( + fx.ctx(), &overlongZSXXMixedPureZRun, "zsxx", + [](func::FuncOp funcOp, const Matrix2x2&) { + ASSERT_EQ(countZSXXGates(funcOp), 3U); + }, + [&fx](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ( + countZSXXGates(funcOp), + expectedGateCount(fx.ctx(), overlongZSXXPureZRunMatrix(), ZSXX)); + expectFusePreserved(funcOp, original, "zsxx"); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, DoesNotFuseAcrossBoundariesAllBases) { + TestFixture fx; + fx.setUp(); + + struct Case { + void (*program)(QCOProgramBuilder&); + void (*check)(func::FuncOp, StringRef, MLIRContext*); + }; + const std::array cases = {{ + {.program = &singleQubitRunsSplitByTwoQGate, + .check = + [](func::FuncOp funcOp, StringRef basis, MLIRContext* ctx) { + std::size_t twoQ = 0; + funcOp.walk([&twoQ](UnitaryOpInterface op) { + if (op.isTwoQubit()) { + ++twoQ; + } + }); + EXPECT_EQ(twoQ, 1U) << basis.str(); + expectSplitFixtureSegments( + funcOp, basis, ctx, [](const Operation& op) { + if (auto unitary = dyn_cast(op)) { + return unitary.isTwoQubit(); + } + return false; + }); + }}, + {.program = &singleQubitRunsSplitByBarrier, + .check = + [](func::FuncOp funcOp, StringRef basis, MLIRContext* ctx) { + EXPECT_EQ(countOps(funcOp), 1U) << basis.str(); + expectSplitFixtureSegments( + funcOp, basis, ctx, + [](const Operation& op) { return isa(op); }); + }}, + {.program = &singleQubitRunsSplitByScfFor, + .check = + [](func::FuncOp funcOp, StringRef basis, MLIRContext* ctx) { + EXPECT_EQ(countOps(funcOp), 1U) << basis.str(); + expectSplitFixtureSegments(funcOp, basis, ctx); + }}, + }}; + + for (const Case& testCase : cases) { + runFuseForAllBases(fx.ctx(), testCase.program, + [&testCase, &fx](func::FuncOp funcOp, StringRef basis, + const Matrix2x2& original) { + testCase.check(funcOp, basis, fx.ctx()); + expectFusePreserved(funcOp, original, basis); + }); + } +} + +TEST(FuseSingleQubitUnitaryRunsTest, EliminatesIdentityInvMultiOpBody) { + TestFixture fx; + fx.setUp(); + runFuseOnProgram( + fx.ctx(), xInverseTwoX, "u", + [](func::FuncOp funcOp, const Matrix2x2&) { + EXPECT_EQ(countOps(funcOp), 4U); + EXPECT_EQ(countOps(funcOp), 1U); + }, + [&fx](func::FuncOp funcOp, const Matrix2x2& original) { + EXPECT_EQ(countOps(funcOp), 0U); + EXPECT_EQ(countOps(funcOp), 0U); + EXPECT_EQ(countOps(funcOp), + expectedGateCount(fx.ctx(), original, U)); + expectMatrixPreserved(funcOp, original, "x-inv-xx-x"); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesRunInMultiQubitInvBody) { + TestFixture fx; + fx.setUp(); + runFuseInParent( + fx.ctx(), inverseMultiQubitBodySingleQubitRun, + [](func::FuncOp funcOp) { + EXPECT_EQ(countOps(funcOp), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }, + [](func::FuncOp funcOp) { + EXPECT_EQ(countOps(funcOp), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesInCtrlBody) { + TestFixture fx; + fx.setUp(); + + runFuseInParent( + fx.ctx(), controlledH, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }); + + runFuseInParent( + fx.ctx(), controlledInverseHT, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 0U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesRunInScfForBody) { + TestFixture fx; + fx.setUp(); + runFuseInParent( + fx.ctx(), &singleQubitRunInScfFor, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }, + [](func::FuncOp funcOp) { + EXPECT_EQ((countInParent(funcOp)), 1U); + EXPECT_EQ((countInParent(funcOp)), 0U); + EXPECT_EQ((countInParent(funcOp)), 0U); + EXPECT_EQ((countInParent(funcOp)), 0U); + }); +} diff --git a/mlir/unittests/Dialect/QCO/Utils/test_unitary_matrix.cpp b/mlir/unittests/Dialect/QCO/Utils/test_unitary_matrix.cpp index 8008dc7561..afa0792415 100644 --- a/mlir/unittests/Dialect/QCO/Utils/test_unitary_matrix.cpp +++ b/mlir/unittests/Dialect/QCO/Utils/test_unitary_matrix.cpp @@ -12,6 +12,7 @@ #include +#include #include #include @@ -53,6 +54,16 @@ TEST(UnitaryMatrix1x1, IsApprox) { EXPECT_TRUE(a.isApprox(b)); EXPECT_FALSE(a.isApprox(Matrix1x1::fromElements(2.0))); EXPECT_TRUE(a.isApprox(Matrix1x1::fromElements(1.1), 0.2)); + EXPECT_EQ((Matrix1x1::fromElements(0.5) * 2.0)(0, 0), 1.0); + Matrix1x1 scaled = Matrix1x1::fromElements(0.5); + scaled *= 2.0; + EXPECT_EQ(scaled(0, 0), 1.0); +} + +TEST(UnitaryMatrix1x1, Adjoint) { + const Matrix1x1 phase = Matrix1x1::fromElements(Complex{0.25, 0.5}); + EXPECT_TRUE(phase.adjoint().isApprox( + Matrix1x1::fromElements(std::conj(phase.value)))); } TEST(UnitaryMatrix2x2, IdentityAndAccess) { @@ -70,6 +81,12 @@ TEST(UnitaryMatrix2x2, MultiplyAdjointTraceDeterminant) { EXPECT_TRUE((x * x).isApprox(identity)); EXPECT_TRUE((identity * x).isApprox(x)); + EXPECT_TRUE((x * std::exp(1i * 0.5)) + .isApprox(Matrix2x2::fromElements(0, std::exp(1i * 0.5), + std::exp(1i * 0.5), 0))); + Matrix2x2 scaled = x; + scaled *= std::exp(1i * 0.5); + EXPECT_TRUE(scaled.isApprox(x * std::exp(1i * 0.5))); EXPECT_TRUE(x.adjoint().isApprox(x)); EXPECT_EQ(x.trace(), Complex(0.0, 0.0)); EXPECT_EQ(identity.trace(), Complex(2.0, 0.0)); @@ -77,6 +94,15 @@ TEST(UnitaryMatrix2x2, MultiplyAdjointTraceDeterminant) { EXPECT_EQ(identity.determinant(), Complex(1.0, 0.0)); } +TEST(UnitaryMatrix2x2, PremultiplyBy) { + const Matrix2x2 x = pauliX(); + const Matrix2x2 y = Matrix2x2::fromElements(1, 0, 0, std::exp(1i * 0.5)); + Matrix2x2 acc = Matrix2x2::identity(); + acc.premultiplyBy(x); + acc.premultiplyBy(y); + EXPECT_TRUE(acc.isApprox(y * x)); +} + TEST(UnitaryMatrix2x2, IsApprox) { const Matrix2x2 a = Matrix2x2::identity(); Matrix2x2 b = a; @@ -90,6 +116,7 @@ TEST(UnitaryMatrix4x4, IdentityAndAccess) { EXPECT_TRUE(identity.isApprox( Matrix4x4::fromElements(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1))); EXPECT_EQ(identity(2, 2), 1.0); + EXPECT_TRUE((swapMatrix() * 2.0)(0, 0) == 2.0); } TEST(UnitaryMatrix4x4, MultiplyAdjointTraceDeterminant) { @@ -98,10 +125,22 @@ TEST(UnitaryMatrix4x4, MultiplyAdjointTraceDeterminant) { EXPECT_TRUE((swap * swap).isApprox(identity)); EXPECT_TRUE(swap.adjoint().isApprox(swap)); + Matrix4x4 scaled = swap; + scaled *= 2.0; + EXPECT_TRUE(scaled.isApprox(swap * 2.0)); EXPECT_EQ(identity.trace(), Complex(4.0, 0.0)); EXPECT_EQ(identity.determinant(), Complex(1.0, 0.0)); } +TEST(UnitaryMatrix4x4, PremultiplyBy) { + const Matrix4x4 swap = swapMatrix(); + const Matrix4x4 phase = Matrix4x4::identity() * std::exp(1i * 0.25); + Matrix4x4 acc = Matrix4x4::identity(); + acc.premultiplyBy(swap); + acc.premultiplyBy(phase); + EXPECT_TRUE(acc.isApprox(phase * swap)); +} + TEST(UnitaryMatrix4x4, IsApprox) { const Matrix4x4 a = Matrix4x4::identity(); Matrix4x4 b = a; @@ -153,6 +192,16 @@ TEST(DynamicMatrix, IdentityAndElementAccess) { EXPECT_EQ(mutableMatrix(1, 1), 0.5); } +TEST(DynamicMatrix, FromAdjoint) { + const Matrix2x2 x = pauliX(); + EXPECT_TRUE(DynamicMatrix::fromAdjoint(x).isApprox(x.adjoint())); + const Complex global = std::polar(1.0, 0.25); + EXPECT_TRUE( + DynamicMatrix::fromAdjoint(x * global).isApprox((x * global).adjoint())); + EXPECT_TRUE(DynamicMatrix(x).isApprox(x)); + EXPECT_TRUE(DynamicMatrix(swapMatrix()).isApprox(swapMatrix())); +} + TEST(DynamicMatrix, AssignFrom) { DynamicMatrix dynamic; @@ -218,10 +267,52 @@ TEST(DynamicMatrix, IsApproxRejectsMismatchedExtents) { EXPECT_FALSE(DynamicMatrix::identity(1).isApprox(DynamicMatrix::identity(2))); } +TEST(Matrix1x1, AssignFromDynamicMatrix) { + const Matrix1x1 phase = Matrix1x1::fromElements(Complex{0.25, 0.5}); + + DynamicMatrix dynamic; + dynamic.assignFrom(phase); + + Matrix1x1 out = Matrix1x1::fromElements(1.0); + EXPECT_TRUE(out.assignFrom(dynamic)); + EXPECT_TRUE(out.isApprox(phase)); + EXPECT_FALSE(out.assignFrom(DynamicMatrix::identity(2))); +} + +TEST(Matrix2x2, AssignFromDynamicMatrix) { + const Matrix2x2 x = pauliX(); + + DynamicMatrix dynamic; + dynamic.assignFrom(x); + + Matrix2x2 out = Matrix2x2::identity(); + EXPECT_TRUE(out.assignFrom(dynamic)); + EXPECT_TRUE(out.isApprox(x)); + EXPECT_FALSE(out.assignFrom(DynamicMatrix::identity(3))); +} + +TEST(Matrix4x4, AssignFromDynamicMatrix) { + const Matrix4x4 swap = swapMatrix(); + + DynamicMatrix dynamic; + dynamic.assignFrom(swap); + + Matrix4x4 out = Matrix4x4::identity(); + EXPECT_TRUE(out.assignFrom(dynamic)); + EXPECT_TRUE(out.isApprox(swap)); + EXPECT_FALSE(out.assignFrom(DynamicMatrix::identity(2))); +} + TEST(DynamicMatrix, IsApproxOverloads) { + const Matrix1x1 phase = Matrix1x1::fromElements(Complex{0.25, 0.5}); const Matrix2x2 x = pauliX(); const Matrix4x4 swap = swapMatrix(); + DynamicMatrix as1x1; + as1x1.assignFrom(phase); + EXPECT_TRUE(as1x1.isApprox(phase)); + EXPECT_FALSE(as1x1.isApprox(Matrix1x1::fromElements(1.0))); + DynamicMatrix as2x2; as2x2.assignFrom(x); EXPECT_TRUE(as2x2.isApprox(x)); @@ -233,6 +324,7 @@ TEST(DynamicMatrix, IsApproxOverloads) { EXPECT_FALSE(as4x4.isApprox(Matrix4x4::identity())); DynamicMatrix wrongDim = DynamicMatrix::identity(3); + EXPECT_FALSE(wrongDim.isApprox(phase)); EXPECT_FALSE(wrongDim.isApprox(x)); EXPECT_FALSE(wrongDim.isApprox(swap)); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 1fadf83eb2..d27b4f5bc0 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -329,6 +329,30 @@ void inverseTwoX(QCOProgramBuilder& b) { }); } +void inverseGphaseX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) { + b.gphase(-0.123); + return SmallVector{b.x(qubits[0])}; + }); +} + +void inverseGphaseBarrier(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) -> SmallVector { + b.gphase(0.123); + return {b.barrier({qubits[0]})[0]}; + }); +} + +void inverseTwoBarriersInInv(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) -> SmallVector { + auto q0 = b.barrier({qubits[0]})[0]; + return {b.barrier({q0})[0]}; + }); +} + void y(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.y(q[0]); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index 1a5f5ce229..149753795a 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -177,6 +177,18 @@ void controlledTwoX(QCOProgramBuilder& b); /// gates. void inverseTwoX(QCOProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to a global phase and an +/// X gate. +void inverseGphaseX(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a global phase and a +/// barrier. +void inverseGphaseBarrier(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to two consecutive +/// barriers. +void inverseTwoBarriersInInv(QCOProgramBuilder& b); + // --- YOp ------------------------------------------------------------------ // /// Creates a circuit with just a Y gate.