diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..1c414f2fe 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -8,7 +8,7 @@ from enum import Enum from fractions import Fraction from math import pi -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, SupportsComplex, SupportsFloat # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never @@ -26,6 +26,8 @@ from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import _ENCODING, Statevec class OutputFormat(Enum): @@ -439,3 +441,491 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm partial_order_to_str(xzcorr.partial_order_layers, output), ) ) + + +# --- Complex amplitude and quantum-state pretty-printing --------------------- +# +# The recognition of "nice" real numbers (fractions, square roots) relies on a +# square-then-rationalize trick: a real ``x`` is matched against ``sqrt(p / q)`` +# by approximating ``x ** 2`` with a rational ``p / q``. This single mechanism +# uniformly handles plain fractions (``1/4``), surds (``√2/2``, ``√3/2``) and, +# combined with :func:`angle_to_str`, the phase of exponentials (``e^{iπ/3}``). + +_DEFAULT_MAX_DENOMINATOR = 1000 +_DEFAULT_ATOL = 1e-9 +_DEFAULT_RTOL = 0.0 +_DEFAULT_PRECISION = 4 + + +def _squarefree_decomposition(n: int) -> tuple[int, int]: + """Decompose a non-negative integer as ``outer ** 2 * inner`` with ``inner`` squarefree. + + Parameters + ---------- + n : int + Non-negative integer to decompose. + + Returns + ------- + tuple[int, int] + ``(outer, inner)`` such that ``outer ** 2 * inner == n`` and ``inner`` is + squarefree. ``n == 0`` returns ``(0, 1)``. + """ + if n == 0: + return 0, 1 + outer = 1 + inner = n + d = 2 + while d * d <= inner: + while inner % (d * d) == 0: + inner //= d * d + outer *= d + d += 1 + return outer, inner + + +def _recognize_sqrt(x: float, max_denominator: int, atol: float, rtol: float) -> tuple[int, int, int] | None: + """Recognize a real number as ``signed_num * sqrt(inner) / den``. + + The recognition approximates ``x ** 2`` by a rational ``p / q``; on success, + ``x = ±sqrt(p / q)`` is rewritten with a rationalized, fully-reduced + denominator. Pure rationals are covered as the special case ``inner == 1``. + + Parameters + ---------- + x : float + Real number to recognize. + max_denominator : int + Maximum denominator allowed when approximating ``x ** 2`` by a rational. + atol : float + Absolute tolerance for that rational approximation. + rtol : float + Relative tolerance for that rational approximation. + + Returns + ------- + tuple[int, int, int] or None + ``(signed_num, inner, den)`` with ``den > 0`` and ``inner`` a positive + squarefree integer, encoding ``x = signed_num * sqrt(inner) / den``. + Returns ``None`` when ``x`` is not recognized as such a value. + """ + square = Fraction(x * x).limit_denominator(max_denominator) + if not math.isclose(x * x, float(square), rel_tol=rtol, abs_tol=atol): + return None + num_outer, num_inner = _squarefree_decomposition(square.numerator) + den_outer, den_inner = _squarefree_decomposition(square.denominator) + # x = ±(num_outer √num_inner) / (den_outer √den_inner); rationalize by √den_inner. + combined_outer, inner = _squarefree_decomposition(num_inner * den_inner) + num = num_outer * combined_outer + den = den_outer * den_inner + divisor = math.gcd(num, den) + num //= divisor + den //= divisor + sign = -1 if x < 0 else 1 + return sign * num, inner, den + + +def _imaginary_unit(output: OutputFormat) -> str: + return r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + + +def _sqrt_str(inner: int, output: OutputFormat) -> str: + """Return the string for ``sqrt(inner)`` (empty when ``inner == 1``).""" + if inner == 1: + return "" + if output == OutputFormat.LaTeX: + return rf"\sqrt{{{inner}}}" + if output == OutputFormat.Unicode: + return f"√{inner}" + return f"sqrt({inner})" + + +def _fraction_str(num: str, den: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}}}{{{den}}}" + return f"{num}/{den}" + + +def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: + """Render ``signed_num * sqrt(inner) / den`` produced by :func:`_recognize_sqrt`.""" + if signed_num == 0: + return "0" + sign = "-" if signed_num < 0 else "" + magnitude = abs(signed_num) + sqrt_part = _sqrt_str(inner, output) + if inner == 1: + numerator = f"{magnitude}" + elif magnitude == 1: + numerator = sqrt_part + else: + numerator = f"{magnitude}{sqrt_part}" + if den == 1: + return f"{sign}{numerator}" + return f"{sign}{_fraction_str(numerator, str(den), output)}" + + +def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: + rec = _recognize_sqrt(x, max_denominator, atol, rtol) + if rec is None: + return None + return _render_real(*rec, output) + + +def _render_imaginary(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: + """Render ``signed_num * sqrt(inner) / den * i`` with the unit leading the numerator. + + The imaginary unit is placed at the front of the numerator (e.g. ``i/2``, + ``i√2/2``, ``3i/4``) so that it reads as ``i`` times a real magnitude, with a + unit coefficient collapsing to a bare ``±i``. + """ + sign = "-" if signed_num < 0 else "" + magnitude = abs(signed_num) + unit = _imaginary_unit(output) + coefficient = "" if magnitude == 1 else f"{magnitude}" + numerator = f"{coefficient}{unit}{_sqrt_str(inner, output)}" + if den == 1: + return f"{sign}{numerator}" + return f"{sign}{_fraction_str(numerator, str(den), output)}" + + +def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: + """Render a purely imaginary value ``x * i``.""" + rec = _recognize_sqrt(x, max_denominator, atol, rtol) + if rec is None: + return None + return _render_imaginary(*rec, output) + + +def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float, rtol: float) -> Fraction | None: + """Return ``theta / pi`` as a simple fraction, or ``None`` if it is not one.""" + value = theta / pi + frac = Fraction(value).limit_denominator(max_denominator) + if math.isclose(value, float(frac), rel_tol=rtol, abs_tol=atol): + return frac + return None + + +def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: + """Render ``z`` as ``r e^{iθ}`` when both ``r`` and ``θ / π`` are recognized.""" + theta = math.atan2(z.imag, z.real) + angle_frac = _recognize_angle_over_pi(theta, max_denominator, atol, rtol) + if angle_frac is None or angle_frac == 0: + return None + radius = _real_to_str(math.hypot(z.real, z.imag), output, max_denominator, atol, rtol) + if radius is None: + return None + sign = "-" if angle_frac < 0 else "" + angle_str = angle_to_str(float(abs(angle_frac)), output) + unit = _imaginary_unit(output) + e_sym = r"\mathrm{e}" if output == OutputFormat.LaTeX else "e" + unit_sep = " " if output == OutputFormat.LaTeX else "*" if output == OutputFormat.ASCII else "" + exponent = f"{sign}{unit}{unit_sep}{angle_str}" + body = f"{e_sym}^{{{exponent}}}" if output == OutputFormat.LaTeX else f"{e_sym}^({exponent})" + if radius == "1": + return body + prefix_sep = " " if output == OutputFormat.LaTeX else "·" if output == OutputFormat.Unicode else "*" + return f"{radius}{prefix_sep}{body}" + + +def _cartesian_to_str( + re: float, im: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float +) -> str | None: + """Render ``re + im i`` when both parts are recognized as nice reals.""" + re_str = _real_to_str(re, output, max_denominator, atol, rtol) + im_rec = _recognize_sqrt(im, max_denominator, atol, rtol) + if re_str is None or im_rec is None: + return None + signed_num, inner, den = im_rec + connector = " - " if signed_num < 0 else " + " + unit = _imaginary_unit(output) + if abs(signed_num) == 1 and inner == 1 and den == 1: + imag = unit + else: + imag = f"{_render_real(abs(signed_num), inner, den, output)}{unit}" + return f"{re_str}{connector}{imag}" + + +def _decimal_to_str(z: complex, output: OutputFormat, precision: int, atol: float) -> str: + """Fallback formatting using rounded decimals with ``precision`` significant digits.""" + unit = _imaginary_unit(output) + if abs(z.imag) <= atol: + return f"{z.real:.{precision}g}" + if abs(z.real) <= atol: + return f"{z.imag:.{precision}g}{unit}" + return f"{z.real:.{precision}g}{z.imag:+.{precision}g}{unit}" + + +def complex_to_str( + value: object, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + rtol: float = _DEFAULT_RTOL, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a human-friendly string representation of a complex number. + + Common values are rendered exactly rather than as floating-point numbers: + fractions (``0.25`` → ``1/4``), square roots (``0.7071…`` → ``√2/2``) and + complex exponentials (``0.5 + 0.866…j`` → ``e^(iπ/3)``). Values that are not + recognized fall back to a rounded decimal representation, and inputs that + cannot be interpreted as complex numbers (e.g. symbolic parameters) are + returned via :func:`str`. + + Parameters + ---------- + value : object + The number to format. Anything supporting conversion to ``complex`` is + accepted; other objects are stringified. + output : OutputFormat + Desired formatting style: ``Unicode`` (``√``, ``π``), ``LaTeX`` + (``\sqrt``, ``\pi``) or ``ASCII`` (``sqrt``, ``pi``). + max_denominator : int, optional + Maximum denominator used when recognizing rational magnitudes and phases + (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). + precision : int, optional + Number of significant digits to use for the decimal fallback when a + value is not recognized as an exact form (default: ``4``). + + Returns + ------- + str + The formatted complex number. + + Examples + -------- + >>> complex_to_str(0.25, OutputFormat.ASCII) + '1/4' + >>> complex_to_str(2**-0.5, OutputFormat.Unicode) + '√2/2' + >>> complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.Unicode) + 'e^(iπ/3)' + >>> complex_to_str(0.123456 + 0.234567j, OutputFormat.ASCII, precision=2) + '0.12+0.23i' + """ + if not isinstance(value, (bool, int, float, complex, SupportsComplex)): + return str(value) + z = complex(value) + if abs(z.real) <= atol and abs(z.imag) <= atol: + return "0" + if abs(z.imag) <= atol: + return _real_to_str(z.real, output, max_denominator, atol, rtol) or _decimal_to_str(z, output, precision, atol) + if abs(z.real) <= atol: + return _imaginary_to_str(z.imag, output, max_denominator, atol, rtol) or _decimal_to_str( + z, output, precision, atol + ) + exponential = _exponential_to_str(z, output, max_denominator, atol, rtol) + # The exponential form is the clearest representation on the unit circle, but for + # other moduli a Cartesian form (e.g. ``1 + i`` rather than ``√2·e^(iπ/4)``) reads + # better, so it is preferred when both parts are recognized. + if exponential is not None and math.isclose(math.hypot(z.real, z.imag), 1.0, rel_tol=rtol, abs_tol=atol): + return exponential + cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol, rtol) + if cartesian is not None: + return cartesian + # Fall back to a radius-prefixed exponential when the Cartesian parts are not nice. + if exponential is not None: + return exponential + return _decimal_to_str(z, output, precision, atol) + + +def _ket_str(ket: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\ket{{{ket}}}" + if output == OutputFormat.Unicode: + return f"|{ket}⟩" + return f"|{ket}>" + + +def _needs_parentheses(coefficient: str) -> bool: + """Whether a coefficient is a sum and must be parenthesized before a ket.""" + return " + " in coefficient or " - " in coefficient + + +def _factor_uniform_magnitude( + amplitudes: Sequence[object], + kets: Sequence[str], + output: OutputFormat, + *, + max_denominator: int, + atol: float, + rtol: float, + precision: int, +) -> str | None: + """Factor a modulus shared by every amplitude, e.g. ``√2/2(|0⟩ + i|1⟩)``. + + Amplitudes that share a common modulus (up to a relative phase) are written as + ``r(φ₀|k₀⟩ + φ₁|k₁⟩ + …)``, with ``r`` the shared modulus and each ``φ`` the + per-component phase ``amplitude / r`` (so a relative phase such as ``i`` is kept + inside the parentheses). Returns ``None`` when the moduli differ, the shared + modulus is trivial (``1``), or the amplitudes are not numeric (e.g. symbolic), + in which case the caller renders term-by-term. + """ + if len(amplitudes) < 2: + return None + try: + values = [complex(amplitude) for amplitude in amplitudes] # type: ignore[call-overload] + except (TypeError, ValueError): + # Non-numeric (e.g. symbolic) amplitudes have no modulus to factor. + return None + radius = abs(values[0]) + if radius == 0 or any(not math.isclose(abs(v), radius, rel_tol=rtol, abs_tol=atol) for v in values): + return None + radius_str = complex_to_str( + radius, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + # A unit modulus is not worth factoring (it would read as ``1(...)``). + if radius_str == "1": + return None + result = "" + for index, (value, ket) in enumerate(zip(values, kets, strict=True)): + phase = complex_to_str( + value / radius, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + # A per-component phase is unit-modulus, so it is rendered as 1, -1, ±i, an + # exponential, or a decimal — never a parenthesizable sum. + if phase == "1": + term = ket + elif phase == "-1": + term = f"-{ket}" + else: + term = f"{phase}{ket}" + if index == 0: + result = term + elif term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return f"{radius_str}({result})" + + +def statevec_to_str( + statevec: Statevec, + output: OutputFormat, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + rtol: float = 0.0, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a ket-notation string representation of a statevector. + + Amplitudes close to zero are omitted (see :meth:`graphix.sim.statevec.Statevec.to_dict`) + and the remaining ones are pretty-printed with :func:`complex_to_str`. + + Parameters + ---------- + statevec : Statevec + The statevector to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`graphix.sim.statevec.Statevec.to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance used both to drop near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance used both to drop near-zero amplitudes and for the + recognition heuristics (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted statevector, e.g. ``√2/2(|00⟩ + |01⟩)``. + """ + amplitudes = statevec.to_dict(encoding, rtol=rtol, atol=atol) + if not amplitudes: + return "0" + amps = list(amplitudes.values()) + kets = [_ket_str(ket, output) for ket in amplitudes] + # When every amplitude shares a modulus (up to a relative phase), factor it out, + # e.g. ``√2/2(|0⟩ + i|1⟩)``. + factored = _factor_uniform_magnitude( + amps, kets, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + if factored is not None: + return factored + coefficients = [ + complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) + for amplitude in amps + ] + result = "" + for index, (coefficient, ket_str) in enumerate(zip(coefficients, kets, strict=True)): + if coefficient == "1": + term = ket_str + elif coefficient == "-1": + term = f"-{ket_str}" + elif _needs_parentheses(coefficient): + term = f"({coefficient}){ket_str}" + else: + term = f"{coefficient}{ket_str}" + if index == 0: + result = term + elif term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return result + + +def density_matrix_to_str( + density_matrix: DensityMatrix, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + rtol: float = _DEFAULT_RTOL, + precision: int = _DEFAULT_PRECISION, +) -> str: + r"""Return a matrix-form string representation of a density matrix. + + Each entry is pretty-printed with :func:`complex_to_str`. ``LaTeX`` output + uses a ``pmatrix`` environment; ``ASCII`` and ``Unicode`` outputs produce a + column-aligned grid. + + Parameters + ---------- + density_matrix : DensityMatrix + The density matrix to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted density matrix. + """ + rows = [ + [ + complex_to_str(entry, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) + for entry in row + ] + for row in density_matrix.rho + ] + if output == OutputFormat.LaTeX: + body = r" \\ ".join(" & ".join(row) for row in rows) + return rf"\begin{{pmatrix}}{body}\end{{pmatrix}}" + widths = [max(len(row[col]) for row in rows) for col in range(len(rows[0]))] if rows else [] + lines = [" ".join(entry.rjust(widths[col]) for col, entry in enumerate(row)) for row in rows] + return "\n".join(f"[ {line} ]" for line in lines) diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..8c6ef8331 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -117,6 +118,44 @@ def __str__(self) -> str: """Return a string description.""" return f"DensityMatrix object, with density matrix {self.rho} and shape {self.dims()}." + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + max_denominator: int = 1000, + atol: float = 1e-9, + rtol: float = 0.0, + precision: int = 4, + ) -> str: + r"""Return a pretty-printed matrix representation of the density matrix. + + Each entry is rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``1/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). + + Returns + ------- + str + The formatted density matrix. + """ + return density_matrix_to_str( + self, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + @override def add_nodes(self, nqubit: int, data: Data) -> None: r""" diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..dfde0dfcc 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -520,6 +521,64 @@ def to_prob_dict( """ return self._to_dict_map(lambda x: np.abs(x) ** 2, encoding, rtol=rtol, atol=atol) + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = 1000, + atol: float = 1e-9, + rtol: float = 0.0, + precision: int = 4, + ) -> str: + r"""Return a pretty-printed ket-notation representation of the statevector. + + Amplitudes are rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``√2/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for dropping near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for dropping near-zero amplitudes (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to + a decimal representation (default: ``4``). + + Returns + ------- + str + The formatted statevector. + + Examples + -------- + >>> from graphix.transpiler import Circuit + >>> circuit = Circuit(2) + >>> circuit.h(0) + >>> circuit.cz(0, 1) + >>> print(circuit.simulate_statevector().statevec.draw()) + √2/2(|00⟩ + |01⟩) + """ + return statevec_to_str( + self, + output, + encoding=encoding, + max_denominator=max_denominator, + atol=atol, + rtol=rtol, + precision=precision, + ) + def _to_dict_map( self, f: Callable[[npt.NDArray[np.object_ | np.complex128]], npt.NDArray[_ScalarT]], diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..87b3febbc 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,8 +1,11 @@ from __future__ import annotations +import cmath +import math from typing import TYPE_CHECKING import networkx as nx +import numpy as np import pytest from numpy.random import PCG64, Generator @@ -13,12 +16,15 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, _factor_uniform_magnitude, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec +from graphix.states import BasicStates from graphix.transpiler import Circuit if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping from graphix.flow.core import PauliFlow @@ -202,3 +208,218 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +def test_complex_to_str_issue_examples() -> None: + # The three canonical examples from the issue. + assert complex_to_str(0.25, OutputFormat.ASCII) == "1/4" + assert complex_to_str(2**-0.5, OutputFormat.Unicode) == "√2/2" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.Unicode) == "e^(iπ/3)" + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (0, {OutputFormat.Unicode: "0"}), + (1e-12, {OutputFormat.Unicode: "0"}), + (1, {OutputFormat.Unicode: "1"}), + (-1, {OutputFormat.Unicode: "-1"}), + (2, {OutputFormat.Unicode: "2"}), + (0.5, {OutputFormat.Unicode: "1/2"}), + (0.25, {OutputFormat.LaTeX: r"\frac{1}{4}"}), + (-0.25, {OutputFormat.Unicode: "-1/4"}), + (2**-0.5, {OutputFormat.Unicode: "√2/2", OutputFormat.LaTeX: r"\frac{\sqrt{2}}{2}"}), + (math.sqrt(3) / 2, {OutputFormat.Unicode: "√3/2"}), + (1j, {OutputFormat.Unicode: "i"}), + (-1j, {OutputFormat.Unicode: "-i"}), + # The imaginary unit leads the numerator (i/2, not 1/2i). + (0.5j, {OutputFormat.Unicode: "i/2", OutputFormat.ASCII: "i/2", OutputFormat.LaTeX: r"\frac{\mathrm{i}}{2}"}), + (-(2**-0.5) * 1j, {OutputFormat.Unicode: "-i√2/2"}), + # Complex exponentials on the unit circle. + (math.cos(math.pi / 4) + math.sin(math.pi / 4) * 1j, {OutputFormat.Unicode: "e^(iπ/4)"}), + # Negative phase keeps the sign inside the exponent. + (math.cos(math.pi / 3) - math.sin(math.pi / 3) * 1j, {OutputFormat.Unicode: "e^(-iπ/3)"}), + ( + 0.5 + math.sqrt(3) / 2 * 1j, + { + OutputFormat.Unicode: "e^(iπ/3)", + OutputFormat.ASCII: "e^(i*pi/3)", + OutputFormat.LaTeX: r"\mathrm{e}^{\mathrm{i} \frac{\pi}{3}}", + }, + ), + # An unrecognized value falls back to a rounded decimal. + (0.123456, {OutputFormat.ASCII: "0.1235"}), + # A non-numeric object is stringified rather than raising. + ("alpha", {OutputFormat.ASCII: "alpha"}), + # |z| != 1 with nice Cartesian parts: the Cartesian form is preferred over the + # radius-prefixed exponential (1 + i rather than √2·e^(iπ/4)). + (1 + 1j, {OutputFormat.Unicode: "1 + i", OutputFormat.ASCII: "1 + i", OutputFormat.LaTeX: r"1 + \mathrm{i}"}), + # When the Cartesian parts are not recognized, the radius-prefixed exponential is the + # last resort before the decimal fallback. + ( + 2 * (math.cos(math.pi / 5) + math.sin(math.pi / 5) * 1j), + { + OutputFormat.Unicode: "2·e^(iπ/5)", + OutputFormat.ASCII: "2*e^(i*pi/5)", + OutputFormat.LaTeX: r"2 \mathrm{e}^{\mathrm{i} \frac{\pi}{5}}", + }, + ), + # Both parts recognized but the phase is not a simple fraction of π: Cartesian form. + ( + 0.5 + 0.25j, + {OutputFormat.Unicode: "1/2 + 1/4i", OutputFormat.LaTeX: r"\frac{1}{2} + \frac{1}{4}\mathrm{i}"}, + ), + # Neither part recognized -> rounded decimal real and imaginary parts. + (0.123456 + 0.234567j, {OutputFormat.Unicode: "0.1235+0.2346i"}), + # Integer multiple of a surd. + (math.sqrt(12), {OutputFormat.Unicode: "2√3"}), + ], +) +def test_complex_to_str_values(value: object, expected: Mapping[OutputFormat, str]) -> None: + for output, text in expected.items(): + assert complex_to_str(value, output) == text + + +def test_statevec_draw() -> None: + bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) + # A magnitude shared by every amplitude is factored out. + assert bell.draw(OutputFormat.Unicode) == "√2/2(|00⟩ + |11⟩)" + assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2(|00> + |11>)" + # LaTeX uses the \ket{...} macro for the basis kets. + assert bell.draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}(\ket{00} + \ket{11})" + + +def test_statevec_draw_single_basis_state() -> None: + state = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + assert state.draw(OutputFormat.Unicode) == "|01⟩" + # LaTeX ket notation for a bare basis state. + assert state.draw(OutputFormat.LaTeX) == r"\ket{01}" + # LSB encoding reverses the ket label. + assert state.draw(OutputFormat.Unicode, encoding="LSB") == "|10⟩" + + +def test_density_matrix_draw() -> None: + dm = DensityMatrix(data=[BasicStates.ZERO]) + assert dm.draw(OutputFormat.ASCII) == "[ 1 0 ]\n[ 0 0 ]" + assert dm.draw(OutputFormat.LaTeX) == r"\begin{pmatrix}1 & 0 \\ 0 & 0\end{pmatrix}" + + +def test_statevec_draw_negative_and_parenthesized() -> None: + # The shared 1/2 magnitude is factored out, with the signs kept inside the parentheses. + neg = Statevec([0.5, -0.5, 0.5, 0.5]) + assert neg.draw(OutputFormat.Unicode) == "1/2(|00⟩ - |01⟩ + |10⟩ + |11⟩)" + # A compound (cartesian) amplitude is parenthesized before the ket. Build from a + # numpy array so the amplitudes are ``numpy.complex128`` (Python's ``complex`` only + # gained ``__complex__`` in 3.11, so a bare ``complex`` is rejected on 3.10). + binomial = Statevec(np.array([0.5 + 0.25j, (1 - abs(0.5 + 0.25j) ** 2) ** 0.5])) + assert binomial.draw(OutputFormat.Unicode) == "(1/2 + 1/4i)|0⟩ + √11/4|1⟩" + # A unit negative amplitude collapses to a bare `-|ket⟩`. + assert Statevec([-1.0, 0.0]).draw(OutputFormat.Unicode) == "-|0⟩" + + +def test_complex_to_str_precision_is_configurable() -> None: + z = 0.123456 + 0.234567j + assert complex_to_str(z, OutputFormat.ASCII, precision=2) == "0.12+0.23i" + assert complex_to_str(z, OutputFormat.ASCII, precision=6) == "0.123456+0.234567i" + # The default keeps the previous behaviour (four significant digits). + assert complex_to_str(z, OutputFormat.ASCII) == "0.1235+0.2346i" + + +def test_complex_to_str_rtol_controls_recognition() -> None: + # A value slightly off 1/2: with the default (tight) tolerances it is not recognized as a + # fraction and falls back to a decimal; a looser relative tolerance recognizes it as 1/2. + x = 0.500001 + assert complex_to_str(x, OutputFormat.Unicode) == "0.5" + assert complex_to_str(x, OutputFormat.Unicode, rtol=1e-4) == "1/2" + + +def test_density_matrix_draw_rtol() -> None: + # `rtol` is accepted by `DensityMatrix.draw` and threaded to the entry recognition. + dm = DensityMatrix(data=[BasicStates.PLUS]) + assert dm.draw(OutputFormat.Unicode) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + assert dm.draw(OutputFormat.Unicode, rtol=1e-4) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + + +def test_statevec_draw_factor_relative_phase() -> None: + # A modulus shared up to a relative phase is still factored, with the phase kept + # inside the parentheses. + assert Statevec(data=BasicStates.PLUS).draw(OutputFormat.Unicode) == "√2/2(|0⟩ + |1⟩)" + assert Statevec(data=BasicStates.PLUS_I).draw(OutputFormat.Unicode) == "√2/2(|0⟩ + i|1⟩)" + assert Statevec(data=BasicStates.MINUS_I).draw(OutputFormat.Unicode) == "√2/2(|0⟩ - i|1⟩)" + assert ( + Statevec(data=BasicStates.PLUS_I).draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}(\ket{0} + \mathrm{i}\ket{1})" + ) + + +def test_factor_uniform_magnitude_edge_cases() -> None: + kets = ["|0⟩", "|1⟩"] + # Non-numeric (e.g. symbolic) amplitudes have no modulus -> None (term-by-term fallback). + assert ( + _factor_uniform_magnitude( + ["alpha", "beta"], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # Zero modulus -> None. + assert ( + _factor_uniform_magnitude( + [0, 0], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # A unit modulus is not worth factoring -> None. + assert ( + _factor_uniform_magnitude( + [1, 1j], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # A non-nice relative phase still factors, with the phase as a decimal coefficient. + factored = _factor_uniform_magnitude( + [0.5, 0.5 * cmath.exp(1j * 0.3)], + kets, + OutputFormat.Unicode, + max_denominator=1000, + atol=1e-9, + rtol=0.0, + precision=4, + ) + assert factored is not None + assert factored.startswith("1/2(|0⟩ + ") + + +def test_draw_max_denominator() -> None: + # `max_denominator` caps the denominators the recognition will accept. + dm = DensityMatrix(data=[BasicStates.PLUS]) + assert dm.draw(OutputFormat.Unicode) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + # With max_denominator=1, 1/2 can no longer be recognized and falls back to a decimal. + assert dm.draw(OutputFormat.Unicode, max_denominator=1) == "[ 0.5 0.5 ]\n[ 0.5 0.5 ]" + # Same effect on the statevector draw (√2/2 -> 0.7071). + sv = Statevec([2**-0.5, 2**-0.5]) + assert sv.draw(OutputFormat.Unicode) == "√2/2(|0⟩ + |1⟩)" + assert sv.draw(OutputFormat.Unicode, max_denominator=1) == "0.7071(|0⟩ + |1⟩)" + + +def test_complex_to_str_edge_cases() -> None: + # A tiny non-zero real collapses to "0" (zero branches of the square-free + # decomposition and the real renderer). + assert complex_to_str(1e-7, OutputFormat.Unicode) == "0" + # A purely imaginary value with no nice form falls back to a decimal imaginary part. + assert complex_to_str(0.234567j, OutputFormat.Unicode) == "0.2346i" + # A recognized phase with an unrecognized radius (π·e^(iπ/4)) cannot use the exponential + # or Cartesian forms and falls back to a decimal. + assert complex_to_str(math.pi * cmath.exp(1j * math.pi / 4), OutputFormat.Unicode) == "2.221+2.221i" + + +def test_statevec_draw_all_amplitudes_dropped() -> None: + # A tolerance large enough to drop every amplitude prints the statevector as "0". + bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) + assert bell.draw(OutputFormat.Unicode, atol=1.0) == "0" + + +def test_statevec_draw_non_uniform_not_factored() -> None: + # Amplitudes with different magnitudes are not factored; each term keeps its coefficient. + # The negative second amplitude also exercises the ` - ` separator in the term-by-term + # fallback (the factoring path handles the uniform-magnitude negative case separately). + state = Statevec([math.sqrt(3) / 2, -0.5]) + assert state.draw(OutputFormat.Unicode) == "√3/2|0⟩ - 1/2|1⟩"