From 4312e84f4b73c567e269353a37144391d5c873dd Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Wed, 3 Jun 2026 08:21:42 +0400 Subject: [PATCH 1/9] Add pretty-printing for Statevec and DensityMatrix (closes #501) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add human-friendly rendering of quantum amplitudes and states: - `complex_to_str` in `pretty_print.py` recognizes common values and renders them exactly instead of as floats: fractions (`1/4`), square roots (`√2/2`) and complex exponentials (`e^(iπ/3)`). Recognition uses a square-then-rationalize heuristic and reuses the existing `angle_to_str` for the exponential phase. Supports ASCII, Unicode and LaTeX output. - `statevec_to_str` + `Statevec.draw` render a statevector in ket notation (e.g. `√2/2|00⟩ + √2/2|01⟩`), honouring the existing `encoding` parameter from `Statevec.to_dict`. - `density_matrix_to_str` + `DensityMatrix.draw` render a density matrix as a column-aligned grid (ASCII/Unicode) or a LaTeX `pmatrix`. - Tests covering the issue examples plus edge cases (zero, negative, pure imaginary, LaTeX, decimal fallback, symbolic), with numpy-style docs. Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 376 +++++++++++++++++++++++++++++++++- graphix/sim/density_matrix.py | 30 +++ graphix/sim/statevec.py | 47 +++++ tests/test_pretty_print.py | 75 ++++++- 4 files changed, 526 insertions(+), 2 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..bd78f6305 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -8,7 +8,7 @@ from enum import Enum from fractions import Fraction from math import pi -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, SupportsComplex, SupportsFloat # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never @@ -26,6 +26,8 @@ from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import _ENCODING, Statevec class OutputFormat(Enum): @@ -439,3 +441,375 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm partial_order_to_str(xzcorr.partial_order_layers, output), ) ) + + +# --- Complex amplitude and quantum-state pretty-printing --------------------- +# +# The recognition of "nice" real numbers (fractions, square roots) relies on a +# square-then-rationalize trick: a real ``x`` is matched against ``sqrt(p / q)`` +# by approximating ``x ** 2`` with a rational ``p / q``. This single mechanism +# uniformly handles plain fractions (``1/4``), surds (``√2/2``, ``√3/2``) and, +# combined with :func:`angle_to_str`, the phase of exponentials (``e^{iπ/3}``). + +_DEFAULT_MAX_DENOMINATOR = 1000 +_DEFAULT_ATOL = 1e-9 + + +def _squarefree_decomposition(n: int) -> tuple[int, int]: + """Decompose a non-negative integer as ``outer ** 2 * inner`` with ``inner`` squarefree. + + Parameters + ---------- + n : int + Non-negative integer to decompose. + + Returns + ------- + tuple[int, int] + ``(outer, inner)`` such that ``outer ** 2 * inner == n`` and ``inner`` is + squarefree. ``n == 0`` returns ``(0, 1)``. + """ + if n == 0: + return 0, 1 + outer = 1 + inner = n + d = 2 + while d * d <= inner: + while inner % (d * d) == 0: + inner //= d * d + outer *= d + d += 1 + return outer, inner + + +def _recognize_real(x: float, max_denominator: int, atol: float) -> tuple[int, int, int] | None: + """Recognize a real number as ``signed_num * sqrt(inner) / den``. + + The recognition approximates ``x ** 2`` by a rational ``p / q``; on success, + ``x = ±sqrt(p / q)`` is rewritten with a rationalized, fully-reduced + denominator. + + Parameters + ---------- + x : float + Real number to recognize. + max_denominator : int + Maximum denominator allowed when approximating ``x ** 2`` by a rational. + atol : float + Absolute tolerance for that rational approximation. + + Returns + ------- + tuple[int, int, int] or None + ``(signed_num, inner, den)`` with ``den > 0`` and ``inner`` a positive + squarefree integer, encoding ``x = signed_num * sqrt(inner) / den``. + Returns ``None`` when ``x`` is not recognized as such a value. + """ + if x == 0: + return 0, 1, 1 + square = Fraction(x * x).limit_denominator(max_denominator) + if not math.isclose(x * x, float(square), abs_tol=atol): + return None + num_outer, num_inner = _squarefree_decomposition(square.numerator) + den_outer, den_inner = _squarefree_decomposition(square.denominator) + # x = ±(num_outer √num_inner) / (den_outer √den_inner); rationalize by √den_inner. + combined_outer, inner = _squarefree_decomposition(num_inner * den_inner) + num = num_outer * combined_outer + den = den_outer * den_inner + divisor = math.gcd(num, den) + num //= divisor + den //= divisor + sign = -1 if x < 0 else 1 + return sign * num, inner, den + + +def _imaginary_unit(output: OutputFormat) -> str: + return r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + + +def _sqrt_str(inner: int, output: OutputFormat) -> str: + """Return the string for ``sqrt(inner)`` (empty when ``inner == 1``).""" + if inner == 1: + return "" + if output == OutputFormat.LaTeX: + return rf"\sqrt{{{inner}}}" + if output == OutputFormat.Unicode: + return f"√{inner}" + return f"sqrt({inner})" + + +def _fraction_str(num: str, den: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}}}{{{den}}}" + return f"{num}/{den}" + + +def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: + """Render ``signed_num * sqrt(inner) / den`` produced by :func:`_recognize_real`.""" + if signed_num == 0: + return "0" + sign = "-" if signed_num < 0 else "" + magnitude = abs(signed_num) + sqrt_part = _sqrt_str(inner, output) + if inner == 1: + numerator = f"{magnitude}" + elif magnitude == 1: + numerator = sqrt_part + else: + numerator = f"{magnitude}{sqrt_part}" + if den == 1: + return f"{sign}{numerator}" + return f"{sign}{_fraction_str(numerator, str(den), output)}" + + +def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + rec = _recognize_real(x, max_denominator, atol) + if rec is None: + return None + return _render_real(*rec, output) + + +def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render a purely imaginary value ``x * i``.""" + rec = _recognize_real(x, max_denominator, atol) + if rec is None: + return None + signed_num, inner, den = rec + unit = _imaginary_unit(output) + # A unit coefficient collapses to just ``±i``. + if abs(signed_num) == 1 and inner == 1 and den == 1: + return f"{'-' if signed_num < 0 else ''}{unit}" + return f"{_render_real(signed_num, inner, den, output)}{unit}" + + +def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float) -> Fraction | None: + """Return ``theta / pi`` as a simple fraction, or ``None`` if it is not one.""" + value = theta / pi + frac = Fraction(value).limit_denominator(max_denominator) + if math.isclose(value, float(frac), abs_tol=atol): + return frac + return None + + +def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render ``z`` as ``r e^{iθ}`` when both ``r`` and ``θ / π`` are recognized.""" + theta = math.atan2(z.imag, z.real) + angle_frac = _recognize_angle_over_pi(theta, max_denominator, atol) + if angle_frac is None or angle_frac == 0: + return None + radius = _real_to_str(math.hypot(z.real, z.imag), output, max_denominator, atol) + if radius is None: + return None + sign = "-" if angle_frac < 0 else "" + angle_str = angle_to_str(float(abs(angle_frac)), output) + unit = _imaginary_unit(output) + e_sym = r"\mathrm{e}" if output == OutputFormat.LaTeX else "e" + unit_sep = " " if output == OutputFormat.LaTeX else "*" if output == OutputFormat.ASCII else "" + exponent = f"{sign}{unit}{unit_sep}{angle_str}" + body = f"{e_sym}^{{{exponent}}}" if output == OutputFormat.LaTeX else f"{e_sym}^({exponent})" + if radius == "1": + return body + prefix_sep = " " if output == OutputFormat.LaTeX else "·" if output == OutputFormat.Unicode else "*" + return f"{radius}{prefix_sep}{body}" + + +def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: + """Render ``re + im i`` when both parts are recognized as nice reals.""" + re_str = _real_to_str(re, output, max_denominator, atol) + im_rec = _recognize_real(im, max_denominator, atol) + if re_str is None or im_rec is None: + return None + signed_num, inner, den = im_rec + connector = " - " if signed_num < 0 else " + " + unit = _imaginary_unit(output) + if abs(signed_num) == 1 and inner == 1 and den == 1: + imag = unit + else: + imag = f"{_render_real(abs(signed_num), inner, den, output)}{unit}" + return f"{re_str}{connector}{imag}" + + +def _decimal_to_str(z: complex, output: OutputFormat) -> str: + """Fallback formatting using rounded decimals.""" + unit = _imaginary_unit(output) + if abs(z.imag) <= _DEFAULT_ATOL: + return f"{z.real:.4g}" + if abs(z.real) <= _DEFAULT_ATOL: + return f"{z.imag:.4g}{unit}" + return f"{z.real:.4g}{z.imag:+.4g}{unit}" + + +def complex_to_str( + value: object, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, +) -> str: + r"""Return a human-friendly string representation of a complex number. + + Common values are rendered exactly rather than as floating-point numbers: + fractions (``0.25`` → ``1/4``), square roots (``0.7071…`` → ``√2/2``) and + complex exponentials (``0.5 + 0.866…j`` → ``e^(iπ/3)``). Values that are not + recognized fall back to a rounded decimal representation, and inputs that + cannot be interpreted as complex numbers (e.g. symbolic parameters) are + returned via :func:`str`. + + Parameters + ---------- + value : object + The number to format. Anything supporting conversion to ``complex`` is + accepted; other objects are stringified. + output : OutputFormat + Desired formatting style: ``Unicode`` (``√``, ``π``), ``LaTeX`` + (``\sqrt``, ``\pi``) or ``ASCII`` (``sqrt``, ``pi``). + max_denominator : int, optional + Maximum denominator used when recognizing rational magnitudes and phases + (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + + Returns + ------- + str + The formatted complex number. + + Examples + -------- + >>> complex_to_str(0.25, OutputFormat.ASCII) + '1/4' + >>> complex_to_str(2**-0.5, OutputFormat.Unicode) + '√2/2' + >>> complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.Unicode) + 'e^(iπ/3)' + """ + if not isinstance(value, (bool, int, float, complex, SupportsComplex)): + return str(value) + z = complex(value) + if abs(z.real) <= atol and abs(z.imag) <= atol: + return "0" + if abs(z.imag) <= atol: + return _real_to_str(z.real, output, max_denominator, atol) or _decimal_to_str(z, output) + if abs(z.real) <= atol: + return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output) + exponential = _exponential_to_str(z, output, max_denominator, atol) + if exponential is not None: + return exponential + cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol) + if cartesian is not None: + return cartesian + return _decimal_to_str(z, output) + + +def _ket_str(ket: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\lvert {ket}\rangle" + if output == OutputFormat.Unicode: + return f"|{ket}⟩" + return f"|{ket}>" + + +def _needs_parentheses(coefficient: str) -> bool: + """Whether a coefficient is a sum and must be parenthesized before a ket.""" + return " + " in coefficient or " - " in coefficient + + +def statevec_to_str( + statevec: Statevec, + output: OutputFormat, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, + rtol: float = 0.0, +) -> str: + r"""Return a ket-notation string representation of a statevector. + + Amplitudes close to zero are omitted (see :meth:`graphix.sim.statevec.Statevec.to_dict`) + and the remaining ones are pretty-printed with :func:`complex_to_str`. + + Parameters + ---------- + statevec : Statevec + The statevector to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`graphix.sim.statevec.Statevec.to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance used both to drop near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance used to drop near-zero amplitudes (default: ``0.0``). + + Returns + ------- + str + The formatted statevector, e.g. ``√2/2|00⟩ + √2/2|01⟩``. + """ + amplitudes = statevec.to_dict(encoding, rtol=rtol, atol=atol) + if not amplitudes: + return "0" + result = "" + for index, (ket, amplitude) in enumerate(amplitudes.items()): + coefficient = complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol) + ket_str = _ket_str(ket, output) + if coefficient == "1": + term = ket_str + elif coefficient == "-1": + term = f"-{ket_str}" + elif _needs_parentheses(coefficient): + term = f"({coefficient}){ket_str}" + else: + term = f"{coefficient}{ket_str}" + if index == 0: + result = term + elif term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return result + + +def density_matrix_to_str( + density_matrix: DensityMatrix, + output: OutputFormat, + *, + max_denominator: int = _DEFAULT_MAX_DENOMINATOR, + atol: float = _DEFAULT_ATOL, +) -> str: + r"""Return a matrix-form string representation of a density matrix. + + Each entry is pretty-printed with :func:`complex_to_str`. ``LaTeX`` output + uses a ``pmatrix`` environment; ``ASCII`` and ``Unicode`` outputs produce a + column-aligned grid. + + Parameters + ---------- + density_matrix : DensityMatrix + The density matrix to format. + output : OutputFormat + Desired formatting style (``ASCII``, ``LaTeX`` or ``Unicode``). + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + + Returns + ------- + str + The formatted density matrix. + """ + rows = [ + [complex_to_str(entry, output, max_denominator=max_denominator, atol=atol) for entry in row] + for row in density_matrix.rho + ] + if output == OutputFormat.LaTeX: + body = r" \\ ".join(" & ".join(row) for row in rows) + return rf"\begin{{pmatrix}}{body}\end{{pmatrix}}" + widths = [max(len(row[col]) for row in rows) for col in range(len(rows[0]))] if rows else [] + lines = [" ".join(entry.rjust(widths[col]) for col, entry in enumerate(row)) for row in rows] + return "\n".join(f"[ {line} ]" for line in lines) diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..0f0af891b 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -117,6 +118,35 @@ def __str__(self) -> str: """Return a string description.""" return f"DensityMatrix object, with density matrix {self.rho} and shape {self.dims()}." + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + max_denominator: int = 1000, + atol: float = 1e-9, + ) -> str: + r"""Return a pretty-printed matrix representation of the density matrix. + + Each entry is rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``1/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + max_denominator : int, optional + Maximum denominator used by the entry recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for the recognition heuristics (default: ``1e-9``). + + Returns + ------- + str + The formatted density matrix. + """ + return density_matrix_to_str(self, output, max_denominator=max_denominator, atol=atol) + @override def add_nodes(self, nqubit: int, data: Data) -> None: r""" diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..9a6d66749 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -520,6 +521,52 @@ def to_prob_dict( """ return self._to_dict_map(lambda x: np.abs(x) ** 2, encoding, rtol=rtol, atol=atol) + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + encoding: _ENCODING = "MSB", + max_denominator: int = 1000, + atol: float = 1e-9, + rtol: float = 0.0, + ) -> str: + r"""Return a pretty-printed ket-notation representation of the statevector. + + Amplitudes are rendered with :func:`graphix.pretty_print.complex_to_str`, + so common values appear as exact expressions (e.g. ``√2/2``) rather than + floating-point numbers. + + Parameters + ---------- + output : OutputFormat, optional + Desired formatting style. Defaults to :attr:`OutputFormat.Unicode`. + encoding : {"LSB", "MSB"}, optional + Bit-ordering convention for the basis kets (default: ``"MSB"``). + See :meth:`to_dict`. + max_denominator : int, optional + Maximum denominator used by the amplitude recognition (default: ``1000``). + atol : float, optional + Absolute tolerance for dropping near-zero amplitudes and for the + recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for dropping near-zero amplitudes (default: ``0.0``). + + Returns + ------- + str + The formatted statevector. + + Examples + -------- + >>> from graphix.transpiler import Circuit + >>> circuit = Circuit(2) + >>> circuit.h(0) + >>> circuit.cz(0, 1) + >>> print(circuit.simulate_statevector().statevec.draw()) + √2/2|00⟩ + √2/2|01⟩ + """ + return statevec_to_str(self, output, encoding=encoding, max_denominator=max_denominator, atol=atol, rtol=rtol) + def _to_dict_map( self, f: Callable[[npt.NDArray[np.object_ | np.complex128]], npt.NDArray[_ScalarT]], diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..deef98fc6 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING import networkx as nx @@ -13,8 +14,11 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec +from graphix.states import BasicStates from graphix.transpiler import Circuit if TYPE_CHECKING: @@ -202,3 +206,72 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +def test_complex_to_str_issue_examples() -> None: + # The three canonical examples from the issue. + assert complex_to_str(0.25, OutputFormat.ASCII) == "1/4" + assert complex_to_str(2**-0.5, OutputFormat.Unicode) == "√2/2" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.Unicode) == "e^(iπ/3)" + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (0, "0"), + (1e-12, "0"), + (1, "1"), + (-1, "-1"), + (2, "2"), + (0.5, "1/2"), + (-0.25, "-1/4"), + (2**-0.5, "√2/2"), + (math.sqrt(3) / 2, "√3/2"), + (1j, "i"), + (-1j, "-i"), + (0.5j, "1/2i"), + (-(2**-0.5) * 1j, "-√2/2i"), + ], +) +def test_complex_to_str_unicode_values(value: complex, expected: str) -> None: + assert complex_to_str(value, OutputFormat.Unicode) == expected + + +def test_complex_to_str_exponentials() -> None: + assert complex_to_str(1j, OutputFormat.Unicode) == "i" + assert complex_to_str(math.cos(math.pi / 4) + math.sin(math.pi / 4) * 1j, OutputFormat.Unicode) == "e^(iπ/4)" + # Negative phase keeps the sign inside the exponent. + assert complex_to_str(math.cos(math.pi / 3) - math.sin(math.pi / 3) * 1j, OutputFormat.Unicode) == "e^(-iπ/3)" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.ASCII) == "e^(i*pi/3)" + + +def test_complex_to_str_latex() -> None: + assert complex_to_str(2**-0.5, OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}" + assert complex_to_str(0.25, OutputFormat.LaTeX) == r"\frac{1}{4}" + assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i} \frac{\pi}{3}}" + + +def test_complex_to_str_fallback_and_symbolic() -> None: + # An unrecognized value falls back to a rounded decimal. + assert complex_to_str(0.123456, OutputFormat.ASCII) == "0.1235" + # A non-numeric object is stringified rather than raising. + assert complex_to_str("alpha", OutputFormat.ASCII) == "alpha" + + +def test_statevec_draw() -> None: + bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) + assert bell.draw(OutputFormat.Unicode) == "√2/2|00⟩ + √2/2|11⟩" + assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2|00> + sqrt(2)/2|11>" + + +def test_statevec_draw_single_basis_state() -> None: + state = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + assert state.draw(OutputFormat.Unicode) == "|01⟩" + # LSB encoding reverses the ket label. + assert state.draw(OutputFormat.Unicode, encoding="LSB") == "|10⟩" + + +def test_density_matrix_draw() -> None: + dm = DensityMatrix(data=[BasicStates.ZERO]) + assert dm.draw(OutputFormat.ASCII) == "[ 1 0 ]\n[ 0 0 ]" + assert dm.draw(OutputFormat.LaTeX) == r"\begin{pmatrix}1 & 0 \\ 0 & 0\end{pmatrix}" From bd488e09f6d0ea1a5dba6d871e0293b93507c998 Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Thu, 4 Jun 2026 10:12:04 +0400 Subject: [PATCH 2/9] Add tests covering complex/statevec formatter edge cases Cover the exponential-with-radius, cartesian, complex-decimal-fallback and LaTeX/ASCII imaginary branches of complex_to_str, the integer-times-root render path, and the negative-term, parenthesized-coefficient and unit-negative branches of Statevec.draw, addressing the patch-coverage gap reported on the PR. Co-Authored-By: Claude Opus 4.8 --- tests/test_pretty_print.py | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index deef98fc6..204b29fb2 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -275,3 +275,42 @@ def test_density_matrix_draw() -> None: dm = DensityMatrix(data=[BasicStates.ZERO]) assert dm.draw(OutputFormat.ASCII) == "[ 1 0 ]\n[ 0 0 ]" assert dm.draw(OutputFormat.LaTeX) == r"\begin{pmatrix}1 & 0 \\ 0 & 0\end{pmatrix}" + + +def test_complex_to_str_exponential_with_radius() -> None: + # |z| != 1: the radius prefixes the exponential form (1 + i = √2 e^{iπ/4}). + assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "√2·e^(iπ/4)" + assert complex_to_str(1 + 1j, OutputFormat.ASCII) == "sqrt(2)*e^(i*pi/4)" + assert complex_to_str(1 + 1j, OutputFormat.LaTeX) == r"\sqrt{2} \mathrm{e}^{\mathrm{i} \frac{\pi}{4}}" + + +def test_complex_to_str_cartesian_form() -> None: + # Both parts are recognized but the phase is not a simple fraction of π, so the + # cartesian form is used instead of the exponential one. + assert complex_to_str(0.5 + 0.25j, OutputFormat.Unicode) == "1/2 + 1/4i" + assert complex_to_str(0.5 + 0.25j, OutputFormat.LaTeX) == r"\frac{1}{2} + \frac{1}{4}\mathrm{i}" + + +def test_complex_to_str_complex_decimal_fallback() -> None: + # Neither part is a recognized value -> rounded decimal real and imaginary parts. + assert complex_to_str(0.123456 + 0.234567j, OutputFormat.Unicode) == "0.1235+0.2346i" + + +def test_complex_to_str_imaginary_formats() -> None: + assert complex_to_str(0.5j, OutputFormat.LaTeX) == r"\frac{1}{2}\mathrm{i}" + assert complex_to_str(0.5j, OutputFormat.ASCII) == "1/2i" + + +def test_complex_to_str_integer_times_sqrt() -> None: + assert complex_to_str(math.sqrt(12), OutputFormat.Unicode) == "2√3" + + +def test_statevec_draw_negative_and_parenthesized() -> None: + # Negative amplitudes use a `-` separator between terms. + neg = Statevec([0.5, -0.5, 0.5, 0.5]) + assert neg.draw(OutputFormat.Unicode) == "1/2|00⟩ - 1/2|01⟩ + 1/2|10⟩ + 1/2|11⟩" + # A compound (cartesian) amplitude is parenthesized before the ket. + binomial = Statevec([0.5 + 0.25j, (1 - abs(0.5 + 0.25j) ** 2) ** 0.5]) + assert binomial.draw(OutputFormat.Unicode) == "(1/2 + 1/4i)|0⟩ + √11/4|1⟩" + # A unit negative amplitude collapses to a bare `-|ket⟩`. + assert Statevec([-1.0 + 0j, 0j]).draw(OutputFormat.Unicode) == "-|0⟩" From 60b60201e327daaefec2ea0d23c7049724fc8155 Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Thu, 4 Jun 2026 11:26:05 +0400 Subject: [PATCH 3/9] Fix Statevec construction in test for Python 3.10 compatibility Build the complex-amplitude statevec from a numpy array (numpy.complex128) rather than a Python complex literal: Python's complex only gained __complex__ in 3.11, so a bare complex is not a typing.SupportsComplex on 3.10 and is rejected by Statevec. Also use a real amplitude for the unit-negative case. Co-Authored-By: Claude Opus 4.8 --- tests/test_pretty_print.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 204b29fb2..087175276 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import networkx as nx +import numpy as np import pytest from numpy.random import PCG64, Generator @@ -309,8 +310,10 @@ def test_statevec_draw_negative_and_parenthesized() -> None: # Negative amplitudes use a `-` separator between terms. neg = Statevec([0.5, -0.5, 0.5, 0.5]) assert neg.draw(OutputFormat.Unicode) == "1/2|00⟩ - 1/2|01⟩ + 1/2|10⟩ + 1/2|11⟩" - # A compound (cartesian) amplitude is parenthesized before the ket. - binomial = Statevec([0.5 + 0.25j, (1 - abs(0.5 + 0.25j) ** 2) ** 0.5]) + # A compound (cartesian) amplitude is parenthesized before the ket. Build from a + # numpy array so the amplitudes are ``numpy.complex128`` (Python's ``complex`` only + # gained ``__complex__`` in 3.11, so a bare ``complex`` is rejected on 3.10). + binomial = Statevec(np.array([0.5 + 0.25j, (1 - abs(0.5 + 0.25j) ** 2) ** 0.5])) assert binomial.draw(OutputFormat.Unicode) == "(1/2 + 1/4i)|0⟩ + √11/4|1⟩" # A unit negative amplitude collapses to a bare `-|ket⟩`. - assert Statevec([-1.0 + 0j, 0j]).draw(OutputFormat.Unicode) == "-|0⟩" + assert Statevec([-1.0, 0.0]).draw(OutputFormat.Unicode) == "-|0⟩" From eebb1048b5d103f5c77dba4ec391cbd69d45b654 Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Fri, 5 Jun 2026 21:50:07 +0400 Subject: [PATCH 4/9] Address review on pretty-print - Add a ``precision`` keyword argument to :func:`complex_to_str`, :func:`statevec_to_str`, :func:`density_matrix_to_str` and the matching :meth:`Statevec.draw` / :meth:`DensityMatrix.draw` methods to control the number of significant digits of the decimal fallback. The default value (``4``) preserves the previous behaviour, and a regression test exercises multiple precisions. - Rename ``_recognize_real`` to ``_recognize_sqrt`` to better reflect the ``signed_num * sqrt(inner) / den`` form the helper returns (pure rationals are covered as ``inner == 1``); update the docstring accordingly. Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 50 ++++++++++++++++++++++++----------- graphix/sim/density_matrix.py | 6 ++++- graphix/sim/statevec.py | 14 +++++++++- tests/test_pretty_print.py | 8 ++++++ 4 files changed, 60 insertions(+), 18 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index bd78f6305..cee5f5e5f 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -453,6 +453,7 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm _DEFAULT_MAX_DENOMINATOR = 1000 _DEFAULT_ATOL = 1e-9 +_DEFAULT_PRECISION = 4 def _squarefree_decomposition(n: int) -> tuple[int, int]: @@ -482,12 +483,12 @@ def _squarefree_decomposition(n: int) -> tuple[int, int]: return outer, inner -def _recognize_real(x: float, max_denominator: int, atol: float) -> tuple[int, int, int] | None: +def _recognize_sqrt(x: float, max_denominator: int, atol: float) -> tuple[int, int, int] | None: """Recognize a real number as ``signed_num * sqrt(inner) / den``. The recognition approximates ``x ** 2`` by a rational ``p / q``; on success, ``x = ±sqrt(p / q)`` is rewritten with a rationalized, fully-reduced - denominator. + denominator. Pure rationals are covered as the special case ``inner == 1``. Parameters ---------- @@ -545,7 +546,7 @@ def _fraction_str(num: str, den: str, output: OutputFormat) -> str: def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: - """Render ``signed_num * sqrt(inner) / den`` produced by :func:`_recognize_real`.""" + """Render ``signed_num * sqrt(inner) / den`` produced by :func:`_recognize_sqrt`.""" if signed_num == 0: return "0" sign = "-" if signed_num < 0 else "" @@ -563,7 +564,7 @@ def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: - rec = _recognize_real(x, max_denominator, atol) + rec = _recognize_sqrt(x, max_denominator, atol) if rec is None: return None return _render_real(*rec, output) @@ -571,7 +572,7 @@ def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: flo def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: """Render a purely imaginary value ``x * i``.""" - rec = _recognize_real(x, max_denominator, atol) + rec = _recognize_sqrt(x, max_denominator, atol) if rec is None: return None signed_num, inner, den = rec @@ -616,7 +617,7 @@ def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: """Render ``re + im i`` when both parts are recognized as nice reals.""" re_str = _real_to_str(re, output, max_denominator, atol) - im_rec = _recognize_real(im, max_denominator, atol) + im_rec = _recognize_sqrt(im, max_denominator, atol) if re_str is None or im_rec is None: return None signed_num, inner, den = im_rec @@ -629,14 +630,14 @@ def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominato return f"{re_str}{connector}{imag}" -def _decimal_to_str(z: complex, output: OutputFormat) -> str: - """Fallback formatting using rounded decimals.""" +def _decimal_to_str(z: complex, output: OutputFormat, precision: int) -> str: + """Fallback formatting using rounded decimals with ``precision`` significant digits.""" unit = _imaginary_unit(output) if abs(z.imag) <= _DEFAULT_ATOL: - return f"{z.real:.4g}" + return f"{z.real:.{precision}g}" if abs(z.real) <= _DEFAULT_ATOL: - return f"{z.imag:.4g}{unit}" - return f"{z.real:.4g}{z.imag:+.4g}{unit}" + return f"{z.imag:.{precision}g}{unit}" + return f"{z.real:.{precision}g}{z.imag:+.{precision}g}{unit}" def complex_to_str( @@ -645,6 +646,7 @@ def complex_to_str( *, max_denominator: int = _DEFAULT_MAX_DENOMINATOR, atol: float = _DEFAULT_ATOL, + precision: int = _DEFAULT_PRECISION, ) -> str: r"""Return a human-friendly string representation of a complex number. @@ -668,6 +670,9 @@ def complex_to_str( (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for the decimal fallback when a + value is not recognized as an exact form (default: ``4``). Returns ------- @@ -682,6 +687,8 @@ def complex_to_str( '√2/2' >>> complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.Unicode) 'e^(iπ/3)' + >>> complex_to_str(0.123456 + 0.234567j, OutputFormat.ASCII, precision=2) + '0.12+0.23i' """ if not isinstance(value, (bool, int, float, complex, SupportsComplex)): return str(value) @@ -689,16 +696,16 @@ def complex_to_str( if abs(z.real) <= atol and abs(z.imag) <= atol: return "0" if abs(z.imag) <= atol: - return _real_to_str(z.real, output, max_denominator, atol) or _decimal_to_str(z, output) + return _real_to_str(z.real, output, max_denominator, atol) or _decimal_to_str(z, output, precision) if abs(z.real) <= atol: - return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output) + return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output, precision) exponential = _exponential_to_str(z, output, max_denominator, atol) if exponential is not None: return exponential cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol) if cartesian is not None: return cartesian - return _decimal_to_str(z, output) + return _decimal_to_str(z, output, precision) def _ket_str(ket: str, output: OutputFormat) -> str: @@ -722,6 +729,7 @@ def statevec_to_str( max_denominator: int = _DEFAULT_MAX_DENOMINATOR, atol: float = _DEFAULT_ATOL, rtol: float = 0.0, + precision: int = _DEFAULT_PRECISION, ) -> str: r"""Return a ket-notation string representation of a statevector. @@ -744,6 +752,9 @@ def statevec_to_str( recognition heuristics (default: ``1e-9``). rtol : float, optional Relative tolerance used to drop near-zero amplitudes (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to a + decimal representation (default: ``4``). Returns ------- @@ -755,7 +766,7 @@ def statevec_to_str( return "0" result = "" for index, (ket, amplitude) in enumerate(amplitudes.items()): - coefficient = complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol) + coefficient = complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, precision=precision) ket_str = _ket_str(ket, output) if coefficient == "1": term = ket_str @@ -780,6 +791,7 @@ def density_matrix_to_str( *, max_denominator: int = _DEFAULT_MAX_DENOMINATOR, atol: float = _DEFAULT_ATOL, + precision: int = _DEFAULT_PRECISION, ) -> str: r"""Return a matrix-form string representation of a density matrix. @@ -797,6 +809,9 @@ def density_matrix_to_str( Maximum denominator used by the entry recognition (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). Returns ------- @@ -804,7 +819,10 @@ def density_matrix_to_str( The formatted density matrix. """ rows = [ - [complex_to_str(entry, output, max_denominator=max_denominator, atol=atol) for entry in row] + [ + complex_to_str(entry, output, max_denominator=max_denominator, atol=atol, precision=precision) + for entry in row + ] for row in density_matrix.rho ] if output == OutputFormat.LaTeX: diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index 0f0af891b..5b5a8f283 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -124,6 +124,7 @@ def draw( *, max_denominator: int = 1000, atol: float = 1e-9, + precision: int = 4, ) -> str: r"""Return a pretty-printed matrix representation of the density matrix. @@ -139,13 +140,16 @@ def draw( Maximum denominator used by the entry recognition (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + precision : int, optional + Number of significant digits to use for entries that fall back to a + decimal representation (default: ``4``). Returns ------- str The formatted density matrix. """ - return density_matrix_to_str(self, output, max_denominator=max_denominator, atol=atol) + return density_matrix_to_str(self, output, max_denominator=max_denominator, atol=atol, precision=precision) @override def add_nodes(self, nqubit: int, data: Data) -> None: diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 9a6d66749..cdfafeca5 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -529,6 +529,7 @@ def draw( max_denominator: int = 1000, atol: float = 1e-9, rtol: float = 0.0, + precision: int = 4, ) -> str: r"""Return a pretty-printed ket-notation representation of the statevector. @@ -550,6 +551,9 @@ def draw( recognition heuristics (default: ``1e-9``). rtol : float, optional Relative tolerance for dropping near-zero amplitudes (default: ``0.0``). + precision : int, optional + Number of significant digits to use for amplitudes that fall back to + a decimal representation (default: ``4``). Returns ------- @@ -565,7 +569,15 @@ def draw( >>> print(circuit.simulate_statevector().statevec.draw()) √2/2|00⟩ + √2/2|01⟩ """ - return statevec_to_str(self, output, encoding=encoding, max_denominator=max_denominator, atol=atol, rtol=rtol) + return statevec_to_str( + self, + output, + encoding=encoding, + max_denominator=max_denominator, + atol=atol, + rtol=rtol, + precision=precision, + ) def _to_dict_map( self, diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 087175276..587313d9e 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -317,3 +317,11 @@ def test_statevec_draw_negative_and_parenthesized() -> None: assert binomial.draw(OutputFormat.Unicode) == "(1/2 + 1/4i)|0⟩ + √11/4|1⟩" # A unit negative amplitude collapses to a bare `-|ket⟩`. assert Statevec([-1.0, 0.0]).draw(OutputFormat.Unicode) == "-|0⟩" + + +def test_complex_to_str_precision_is_configurable() -> None: + z = 0.123456 + 0.234567j + assert complex_to_str(z, OutputFormat.ASCII, precision=2) == "0.12+0.23i" + assert complex_to_str(z, OutputFormat.ASCII, precision=6) == "0.123456+0.234567i" + # The default keeps the previous behaviour (four significant digits). + assert complex_to_str(z, OutputFormat.ASCII) == "0.1235+0.2346i" From d3f8a560866a4014f98da849ecf70fd7513cc57d Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Wed, 10 Jun 2026 03:11:19 +0400 Subject: [PATCH 5/9] Address review: ket macro, i-first imaginary, Cartesian preference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use \ket{} in LaTeX kets instead of \lvert ... \rangle - Render pure imaginary values with the unit leading the numerator (i/2, -i√2/2) via a dedicated _render_imaginary helper - Prefer the Cartesian form over a radius-prefixed exponential when |z| != 1 (1 + i instead of √2·e^(iπ/4)); keep the radius-prefixed exponential only as a last resort before the decimal fallback - Merge the format-specific complex_to_str tests into a single parametrized test keyed by a Mapping[OutputFormat, str] Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 34 ++++++++++---- tests/test_pretty_print.py | 92 +++++++++++++++++++------------------- 2 files changed, 72 insertions(+), 54 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index cee5f5e5f..de3a56bd4 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -570,17 +570,29 @@ def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: flo return _render_real(*rec, output) +def _render_imaginary(signed_num: int, inner: int, den: int, output: OutputFormat) -> str: + """Render ``signed_num * sqrt(inner) / den * i`` with the unit leading the numerator. + + The imaginary unit is placed at the front of the numerator (e.g. ``i/2``, + ``i√2/2``, ``3i/4``) so that it reads as ``i`` times a real magnitude, with a + unit coefficient collapsing to a bare ``±i``. + """ + sign = "-" if signed_num < 0 else "" + magnitude = abs(signed_num) + unit = _imaginary_unit(output) + coefficient = "" if magnitude == 1 else f"{magnitude}" + numerator = f"{coefficient}{unit}{_sqrt_str(inner, output)}" + if den == 1: + return f"{sign}{numerator}" + return f"{sign}{_fraction_str(numerator, str(den), output)}" + + def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: """Render a purely imaginary value ``x * i``.""" rec = _recognize_sqrt(x, max_denominator, atol) if rec is None: return None - signed_num, inner, den = rec - unit = _imaginary_unit(output) - # A unit coefficient collapses to just ``±i``. - if abs(signed_num) == 1 and inner == 1 and den == 1: - return f"{'-' if signed_num < 0 else ''}{unit}" - return f"{_render_real(signed_num, inner, den, output)}{unit}" + return _render_imaginary(*rec, output) def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float) -> Fraction | None: @@ -700,17 +712,23 @@ def complex_to_str( if abs(z.real) <= atol: return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output, precision) exponential = _exponential_to_str(z, output, max_denominator, atol) - if exponential is not None: + # The exponential form is the clearest representation on the unit circle, but for + # other moduli a Cartesian form (e.g. ``1 + i`` rather than ``√2·e^(iπ/4)``) reads + # better, so it is preferred when both parts are recognized. + if exponential is not None and math.isclose(math.hypot(z.real, z.imag), 1.0, abs_tol=atol): return exponential cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol) if cartesian is not None: return cartesian + # Fall back to a radius-prefixed exponential when the Cartesian parts are not nice. + if exponential is not None: + return exponential return _decimal_to_str(z, output, precision) def _ket_str(ket: str, output: OutputFormat) -> str: if output == OutputFormat.LaTeX: - return rf"\lvert {ket}\rangle" + return rf"\ket{{{ket}}}" if output == OutputFormat.Unicode: return f"|{ket}⟩" return f"|{ket}>" diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 587313d9e..470865640 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -23,7 +23,7 @@ from graphix.transpiler import Circuit if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping from graphix.flow.core import PauliFlow @@ -219,44 +219,42 @@ def test_complex_to_str_issue_examples() -> None: @pytest.mark.parametrize( ("value", "expected"), [ - (0, "0"), - (1e-12, "0"), - (1, "1"), - (-1, "-1"), - (2, "2"), - (0.5, "1/2"), - (-0.25, "-1/4"), - (2**-0.5, "√2/2"), - (math.sqrt(3) / 2, "√3/2"), - (1j, "i"), - (-1j, "-i"), - (0.5j, "1/2i"), - (-(2**-0.5) * 1j, "-√2/2i"), + (0, {OutputFormat.Unicode: "0"}), + (1e-12, {OutputFormat.Unicode: "0"}), + (1, {OutputFormat.Unicode: "1"}), + (-1, {OutputFormat.Unicode: "-1"}), + (2, {OutputFormat.Unicode: "2"}), + (0.5, {OutputFormat.Unicode: "1/2"}), + (0.25, {OutputFormat.LaTeX: r"\frac{1}{4}"}), + (-0.25, {OutputFormat.Unicode: "-1/4"}), + (2**-0.5, {OutputFormat.Unicode: "√2/2", OutputFormat.LaTeX: r"\frac{\sqrt{2}}{2}"}), + (math.sqrt(3) / 2, {OutputFormat.Unicode: "√3/2"}), + (1j, {OutputFormat.Unicode: "i"}), + (-1j, {OutputFormat.Unicode: "-i"}), + # The imaginary unit leads the numerator (i/2, not 1/2i). + (0.5j, {OutputFormat.Unicode: "i/2", OutputFormat.ASCII: "i/2", OutputFormat.LaTeX: r"\frac{\mathrm{i}}{2}"}), + (-(2**-0.5) * 1j, {OutputFormat.Unicode: "-i√2/2"}), + # Complex exponentials on the unit circle. + (math.cos(math.pi / 4) + math.sin(math.pi / 4) * 1j, {OutputFormat.Unicode: "e^(iπ/4)"}), + # Negative phase keeps the sign inside the exponent. + (math.cos(math.pi / 3) - math.sin(math.pi / 3) * 1j, {OutputFormat.Unicode: "e^(-iπ/3)"}), + ( + 0.5 + math.sqrt(3) / 2 * 1j, + { + OutputFormat.Unicode: "e^(iπ/3)", + OutputFormat.ASCII: "e^(i*pi/3)", + OutputFormat.LaTeX: r"\mathrm{e}^{\mathrm{i} \frac{\pi}{3}}", + }, + ), + # An unrecognized value falls back to a rounded decimal. + (0.123456, {OutputFormat.ASCII: "0.1235"}), + # A non-numeric object is stringified rather than raising. + ("alpha", {OutputFormat.ASCII: "alpha"}), ], ) -def test_complex_to_str_unicode_values(value: complex, expected: str) -> None: - assert complex_to_str(value, OutputFormat.Unicode) == expected - - -def test_complex_to_str_exponentials() -> None: - assert complex_to_str(1j, OutputFormat.Unicode) == "i" - assert complex_to_str(math.cos(math.pi / 4) + math.sin(math.pi / 4) * 1j, OutputFormat.Unicode) == "e^(iπ/4)" - # Negative phase keeps the sign inside the exponent. - assert complex_to_str(math.cos(math.pi / 3) - math.sin(math.pi / 3) * 1j, OutputFormat.Unicode) == "e^(-iπ/3)" - assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.ASCII) == "e^(i*pi/3)" - - -def test_complex_to_str_latex() -> None: - assert complex_to_str(2**-0.5, OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}" - assert complex_to_str(0.25, OutputFormat.LaTeX) == r"\frac{1}{4}" - assert complex_to_str(0.5 + math.sqrt(3) / 2 * 1j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i} \frac{\pi}{3}}" - - -def test_complex_to_str_fallback_and_symbolic() -> None: - # An unrecognized value falls back to a rounded decimal. - assert complex_to_str(0.123456, OutputFormat.ASCII) == "0.1235" - # A non-numeric object is stringified rather than raising. - assert complex_to_str("alpha", OutputFormat.ASCII) == "alpha" +def test_complex_to_str_values(value: object, expected: Mapping[OutputFormat, str]) -> None: + for output, text in expected.items(): + assert complex_to_str(value, output) == text def test_statevec_draw() -> None: @@ -279,10 +277,17 @@ def test_density_matrix_draw() -> None: def test_complex_to_str_exponential_with_radius() -> None: - # |z| != 1: the radius prefixes the exponential form (1 + i = √2 e^{iπ/4}). - assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "√2·e^(iπ/4)" - assert complex_to_str(1 + 1j, OutputFormat.ASCII) == "sqrt(2)*e^(i*pi/4)" - assert complex_to_str(1 + 1j, OutputFormat.LaTeX) == r"\sqrt{2} \mathrm{e}^{\mathrm{i} \frac{\pi}{4}}" + # |z| != 1 with nice Cartesian parts: the Cartesian form is preferred over the + # radius-prefixed exponential (1 + i rather than √2·e^(iπ/4)). + assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "1 + i" + assert complex_to_str(1 + 1j, OutputFormat.ASCII) == "1 + i" + assert complex_to_str(1 + 1j, OutputFormat.LaTeX) == r"1 + \mathrm{i}" + # When the Cartesian parts are not recognized, the radius-prefixed exponential is + # used as a last resort before the decimal fallback. + z = 2 * (math.cos(math.pi / 5) + math.sin(math.pi / 5) * 1j) + assert complex_to_str(z, OutputFormat.Unicode) == "2·e^(iπ/5)" + assert complex_to_str(z, OutputFormat.ASCII) == "2*e^(i*pi/5)" + assert complex_to_str(z, OutputFormat.LaTeX) == r"2 \mathrm{e}^{\mathrm{i} \frac{\pi}{5}}" def test_complex_to_str_cartesian_form() -> None: @@ -297,11 +302,6 @@ def test_complex_to_str_complex_decimal_fallback() -> None: assert complex_to_str(0.123456 + 0.234567j, OutputFormat.Unicode) == "0.1235+0.2346i" -def test_complex_to_str_imaginary_formats() -> None: - assert complex_to_str(0.5j, OutputFormat.LaTeX) == r"\frac{1}{2}\mathrm{i}" - assert complex_to_str(0.5j, OutputFormat.ASCII) == "1/2i" - - def test_complex_to_str_integer_times_sqrt() -> None: assert complex_to_str(math.sqrt(12), OutputFormat.Unicode) == "2√3" From 58eb451ddc2f65fe0bfb898bda64f86b985badc8 Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Fri, 12 Jun 2026 01:44:42 +0400 Subject: [PATCH 6/9] Address review: rtol parameter, atol threading, test coverage Responds to @thierry-martinez's review on #524: - Add an `rtol` parameter everywhere an `atol` is exposed: `complex_to_str`, `density_matrix_to_str`/`DensityMatrix.draw` (it was missing), and threaded through the recognition helpers into the `math.isclose` calls. The default (0.0) preserves the previous behaviour. - `_decimal_to_str` now takes `atol` as an argument instead of relying on the module-level `_DEFAULT_ATOL`. - Test coverage: - Merge the remaining `complex_to_str` value tests (radius/Cartesian/decimal fallback/integer-times-surd) into the parametrized `test_complex_to_str_values`. - Cover the LaTeX `\ket{...}` notation in the statevec draw tests (was untested). - Add tests exercising `rtol` (recognition control + `DensityMatrix.draw`). Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 68 ++++++++++++++++++++------------- graphix/sim/density_matrix.py | 7 +++- tests/test_pretty_print.py | 71 ++++++++++++++++++++--------------- 3 files changed, 89 insertions(+), 57 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index de3a56bd4..072051e6f 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -453,6 +453,7 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm _DEFAULT_MAX_DENOMINATOR = 1000 _DEFAULT_ATOL = 1e-9 +_DEFAULT_RTOL = 0.0 _DEFAULT_PRECISION = 4 @@ -483,7 +484,7 @@ def _squarefree_decomposition(n: int) -> tuple[int, int]: return outer, inner -def _recognize_sqrt(x: float, max_denominator: int, atol: float) -> tuple[int, int, int] | None: +def _recognize_sqrt(x: float, max_denominator: int, atol: float, rtol: float) -> tuple[int, int, int] | None: """Recognize a real number as ``signed_num * sqrt(inner) / den``. The recognition approximates ``x ** 2`` by a rational ``p / q``; on success, @@ -498,6 +499,8 @@ def _recognize_sqrt(x: float, max_denominator: int, atol: float) -> tuple[int, i Maximum denominator allowed when approximating ``x ** 2`` by a rational. atol : float Absolute tolerance for that rational approximation. + rtol : float + Relative tolerance for that rational approximation. Returns ------- @@ -509,7 +512,7 @@ def _recognize_sqrt(x: float, max_denominator: int, atol: float) -> tuple[int, i if x == 0: return 0, 1, 1 square = Fraction(x * x).limit_denominator(max_denominator) - if not math.isclose(x * x, float(square), abs_tol=atol): + if not math.isclose(x * x, float(square), rel_tol=rtol, abs_tol=atol): return None num_outer, num_inner = _squarefree_decomposition(square.numerator) den_outer, den_inner = _squarefree_decomposition(square.denominator) @@ -563,8 +566,8 @@ def _render_real(signed_num: int, inner: int, den: int, output: OutputFormat) -> return f"{sign}{_fraction_str(numerator, str(den), output)}" -def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: - rec = _recognize_sqrt(x, max_denominator, atol) +def _real_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: + rec = _recognize_sqrt(x, max_denominator, atol, rtol) if rec is None: return None return _render_real(*rec, output) @@ -587,30 +590,30 @@ def _render_imaginary(signed_num: int, inner: int, den: int, output: OutputForma return f"{sign}{_fraction_str(numerator, str(den), output)}" -def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: +def _imaginary_to_str(x: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: """Render a purely imaginary value ``x * i``.""" - rec = _recognize_sqrt(x, max_denominator, atol) + rec = _recognize_sqrt(x, max_denominator, atol, rtol) if rec is None: return None return _render_imaginary(*rec, output) -def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float) -> Fraction | None: +def _recognize_angle_over_pi(theta: float, max_denominator: int, atol: float, rtol: float) -> Fraction | None: """Return ``theta / pi`` as a simple fraction, or ``None`` if it is not one.""" value = theta / pi frac = Fraction(value).limit_denominator(max_denominator) - if math.isclose(value, float(frac), abs_tol=atol): + if math.isclose(value, float(frac), rel_tol=rtol, abs_tol=atol): return frac return None -def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, atol: float) -> str | None: +def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, atol: float, rtol: float) -> str | None: """Render ``z`` as ``r e^{iθ}`` when both ``r`` and ``θ / π`` are recognized.""" theta = math.atan2(z.imag, z.real) - angle_frac = _recognize_angle_over_pi(theta, max_denominator, atol) + angle_frac = _recognize_angle_over_pi(theta, max_denominator, atol, rtol) if angle_frac is None or angle_frac == 0: return None - radius = _real_to_str(math.hypot(z.real, z.imag), output, max_denominator, atol) + radius = _real_to_str(math.hypot(z.real, z.imag), output, max_denominator, atol, rtol) if radius is None: return None sign = "-" if angle_frac < 0 else "" @@ -626,10 +629,12 @@ def _exponential_to_str(z: complex, output: OutputFormat, max_denominator: int, return f"{radius}{prefix_sep}{body}" -def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominator: int, atol: float) -> str | None: +def _cartesian_to_str( + re: float, im: float, output: OutputFormat, max_denominator: int, atol: float, rtol: float +) -> str | None: """Render ``re + im i`` when both parts are recognized as nice reals.""" - re_str = _real_to_str(re, output, max_denominator, atol) - im_rec = _recognize_sqrt(im, max_denominator, atol) + re_str = _real_to_str(re, output, max_denominator, atol, rtol) + im_rec = _recognize_sqrt(im, max_denominator, atol, rtol) if re_str is None or im_rec is None: return None signed_num, inner, den = im_rec @@ -642,12 +647,12 @@ def _cartesian_to_str(re: float, im: float, output: OutputFormat, max_denominato return f"{re_str}{connector}{imag}" -def _decimal_to_str(z: complex, output: OutputFormat, precision: int) -> str: +def _decimal_to_str(z: complex, output: OutputFormat, precision: int, atol: float) -> str: """Fallback formatting using rounded decimals with ``precision`` significant digits.""" unit = _imaginary_unit(output) - if abs(z.imag) <= _DEFAULT_ATOL: + if abs(z.imag) <= atol: return f"{z.real:.{precision}g}" - if abs(z.real) <= _DEFAULT_ATOL: + if abs(z.real) <= atol: return f"{z.imag:.{precision}g}{unit}" return f"{z.real:.{precision}g}{z.imag:+.{precision}g}{unit}" @@ -658,6 +663,7 @@ def complex_to_str( *, max_denominator: int = _DEFAULT_MAX_DENOMINATOR, atol: float = _DEFAULT_ATOL, + rtol: float = _DEFAULT_RTOL, precision: int = _DEFAULT_PRECISION, ) -> str: r"""Return a human-friendly string representation of a complex number. @@ -682,6 +688,8 @@ def complex_to_str( (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). precision : int, optional Number of significant digits to use for the decimal fallback when a value is not recognized as an exact form (default: ``4``). @@ -708,22 +716,24 @@ def complex_to_str( if abs(z.real) <= atol and abs(z.imag) <= atol: return "0" if abs(z.imag) <= atol: - return _real_to_str(z.real, output, max_denominator, atol) or _decimal_to_str(z, output, precision) + return _real_to_str(z.real, output, max_denominator, atol, rtol) or _decimal_to_str(z, output, precision, atol) if abs(z.real) <= atol: - return _imaginary_to_str(z.imag, output, max_denominator, atol) or _decimal_to_str(z, output, precision) - exponential = _exponential_to_str(z, output, max_denominator, atol) + return _imaginary_to_str(z.imag, output, max_denominator, atol, rtol) or _decimal_to_str( + z, output, precision, atol + ) + exponential = _exponential_to_str(z, output, max_denominator, atol, rtol) # The exponential form is the clearest representation on the unit circle, but for # other moduli a Cartesian form (e.g. ``1 + i`` rather than ``√2·e^(iπ/4)``) reads # better, so it is preferred when both parts are recognized. - if exponential is not None and math.isclose(math.hypot(z.real, z.imag), 1.0, abs_tol=atol): + if exponential is not None and math.isclose(math.hypot(z.real, z.imag), 1.0, rel_tol=rtol, abs_tol=atol): return exponential - cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol) + cartesian = _cartesian_to_str(z.real, z.imag, output, max_denominator, atol, rtol) if cartesian is not None: return cartesian # Fall back to a radius-prefixed exponential when the Cartesian parts are not nice. if exponential is not None: return exponential - return _decimal_to_str(z, output, precision) + return _decimal_to_str(z, output, precision, atol) def _ket_str(ket: str, output: OutputFormat) -> str: @@ -769,7 +779,8 @@ def statevec_to_str( Absolute tolerance used both to drop near-zero amplitudes and for the recognition heuristics (default: ``1e-9``). rtol : float, optional - Relative tolerance used to drop near-zero amplitudes (default: ``0.0``). + Relative tolerance used both to drop near-zero amplitudes and for the + recognition heuristics (default: ``0.0``). precision : int, optional Number of significant digits to use for amplitudes that fall back to a decimal representation (default: ``4``). @@ -784,7 +795,9 @@ def statevec_to_str( return "0" result = "" for index, (ket, amplitude) in enumerate(amplitudes.items()): - coefficient = complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, precision=precision) + coefficient = complex_to_str( + amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) ket_str = _ket_str(ket, output) if coefficient == "1": term = ket_str @@ -809,6 +822,7 @@ def density_matrix_to_str( *, max_denominator: int = _DEFAULT_MAX_DENOMINATOR, atol: float = _DEFAULT_ATOL, + rtol: float = _DEFAULT_RTOL, precision: int = _DEFAULT_PRECISION, ) -> str: r"""Return a matrix-form string representation of a density matrix. @@ -827,6 +841,8 @@ def density_matrix_to_str( Maximum denominator used by the entry recognition (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). precision : int, optional Number of significant digits to use for entries that fall back to a decimal representation (default: ``4``). @@ -838,7 +854,7 @@ def density_matrix_to_str( """ rows = [ [ - complex_to_str(entry, output, max_denominator=max_denominator, atol=atol, precision=precision) + complex_to_str(entry, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) for entry in row ] for row in density_matrix.rho diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index 5b5a8f283..8c6ef8331 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -124,6 +124,7 @@ def draw( *, max_denominator: int = 1000, atol: float = 1e-9, + rtol: float = 0.0, precision: int = 4, ) -> str: r"""Return a pretty-printed matrix representation of the density matrix. @@ -140,6 +141,8 @@ def draw( Maximum denominator used by the entry recognition (default: ``1000``). atol : float, optional Absolute tolerance for the recognition heuristics (default: ``1e-9``). + rtol : float, optional + Relative tolerance for the recognition heuristics (default: ``0.0``). precision : int, optional Number of significant digits to use for entries that fall back to a decimal representation (default: ``4``). @@ -149,7 +152,9 @@ def draw( str The formatted density matrix. """ - return density_matrix_to_str(self, output, max_denominator=max_denominator, atol=atol, precision=precision) + return density_matrix_to_str( + self, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) @override def add_nodes(self, nqubit: int, data: Data) -> None: diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 470865640..d44f45479 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -250,6 +250,28 @@ def test_complex_to_str_issue_examples() -> None: (0.123456, {OutputFormat.ASCII: "0.1235"}), # A non-numeric object is stringified rather than raising. ("alpha", {OutputFormat.ASCII: "alpha"}), + # |z| != 1 with nice Cartesian parts: the Cartesian form is preferred over the + # radius-prefixed exponential (1 + i rather than √2·e^(iπ/4)). + (1 + 1j, {OutputFormat.Unicode: "1 + i", OutputFormat.ASCII: "1 + i", OutputFormat.LaTeX: r"1 + \mathrm{i}"}), + # When the Cartesian parts are not recognized, the radius-prefixed exponential is the + # last resort before the decimal fallback. + ( + 2 * (math.cos(math.pi / 5) + math.sin(math.pi / 5) * 1j), + { + OutputFormat.Unicode: "2·e^(iπ/5)", + OutputFormat.ASCII: "2*e^(i*pi/5)", + OutputFormat.LaTeX: r"2 \mathrm{e}^{\mathrm{i} \frac{\pi}{5}}", + }, + ), + # Both parts recognized but the phase is not a simple fraction of π: Cartesian form. + ( + 0.5 + 0.25j, + {OutputFormat.Unicode: "1/2 + 1/4i", OutputFormat.LaTeX: r"\frac{1}{2} + \frac{1}{4}\mathrm{i}"}, + ), + # Neither part recognized -> rounded decimal real and imaginary parts. + (0.123456 + 0.234567j, {OutputFormat.Unicode: "0.1235+0.2346i"}), + # Integer multiple of a surd. + (math.sqrt(12), {OutputFormat.Unicode: "2√3"}), ], ) def test_complex_to_str_values(value: object, expected: Mapping[OutputFormat, str]) -> None: @@ -261,11 +283,15 @@ def test_statevec_draw() -> None: bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) assert bell.draw(OutputFormat.Unicode) == "√2/2|00⟩ + √2/2|11⟩" assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2|00> + sqrt(2)/2|11>" + # LaTeX uses the \ket{...} macro for the basis kets. + assert bell.draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}\ket{00} + \frac{\sqrt{2}}{2}\ket{11}" def test_statevec_draw_single_basis_state() -> None: state = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) assert state.draw(OutputFormat.Unicode) == "|01⟩" + # LaTeX ket notation for a bare basis state. + assert state.draw(OutputFormat.LaTeX) == r"\ket{01}" # LSB encoding reverses the ket label. assert state.draw(OutputFormat.Unicode, encoding="LSB") == "|10⟩" @@ -276,36 +302,6 @@ def test_density_matrix_draw() -> None: assert dm.draw(OutputFormat.LaTeX) == r"\begin{pmatrix}1 & 0 \\ 0 & 0\end{pmatrix}" -def test_complex_to_str_exponential_with_radius() -> None: - # |z| != 1 with nice Cartesian parts: the Cartesian form is preferred over the - # radius-prefixed exponential (1 + i rather than √2·e^(iπ/4)). - assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "1 + i" - assert complex_to_str(1 + 1j, OutputFormat.ASCII) == "1 + i" - assert complex_to_str(1 + 1j, OutputFormat.LaTeX) == r"1 + \mathrm{i}" - # When the Cartesian parts are not recognized, the radius-prefixed exponential is - # used as a last resort before the decimal fallback. - z = 2 * (math.cos(math.pi / 5) + math.sin(math.pi / 5) * 1j) - assert complex_to_str(z, OutputFormat.Unicode) == "2·e^(iπ/5)" - assert complex_to_str(z, OutputFormat.ASCII) == "2*e^(i*pi/5)" - assert complex_to_str(z, OutputFormat.LaTeX) == r"2 \mathrm{e}^{\mathrm{i} \frac{\pi}{5}}" - - -def test_complex_to_str_cartesian_form() -> None: - # Both parts are recognized but the phase is not a simple fraction of π, so the - # cartesian form is used instead of the exponential one. - assert complex_to_str(0.5 + 0.25j, OutputFormat.Unicode) == "1/2 + 1/4i" - assert complex_to_str(0.5 + 0.25j, OutputFormat.LaTeX) == r"\frac{1}{2} + \frac{1}{4}\mathrm{i}" - - -def test_complex_to_str_complex_decimal_fallback() -> None: - # Neither part is a recognized value -> rounded decimal real and imaginary parts. - assert complex_to_str(0.123456 + 0.234567j, OutputFormat.Unicode) == "0.1235+0.2346i" - - -def test_complex_to_str_integer_times_sqrt() -> None: - assert complex_to_str(math.sqrt(12), OutputFormat.Unicode) == "2√3" - - def test_statevec_draw_negative_and_parenthesized() -> None: # Negative amplitudes use a `-` separator between terms. neg = Statevec([0.5, -0.5, 0.5, 0.5]) @@ -325,3 +321,18 @@ def test_complex_to_str_precision_is_configurable() -> None: assert complex_to_str(z, OutputFormat.ASCII, precision=6) == "0.123456+0.234567i" # The default keeps the previous behaviour (four significant digits). assert complex_to_str(z, OutputFormat.ASCII) == "0.1235+0.2346i" + + +def test_complex_to_str_rtol_controls_recognition() -> None: + # A value slightly off 1/2: with the default (tight) tolerances it is not recognized as a + # fraction and falls back to a decimal; a looser relative tolerance recognizes it as 1/2. + x = 0.500001 + assert complex_to_str(x, OutputFormat.Unicode) == "0.5" + assert complex_to_str(x, OutputFormat.Unicode, rtol=1e-4) == "1/2" + + +def test_density_matrix_draw_rtol() -> None: + # `rtol` is accepted by `DensityMatrix.draw` and threaded to the entry recognition. + dm = DensityMatrix(data=[BasicStates.PLUS]) + assert dm.draw(OutputFormat.Unicode) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + assert dm.draw(OutputFormat.Unicode, rtol=1e-4) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" From 5b7c3896e078af51293dbbe52b8d155eb338fc9e Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Fri, 12 Jun 2026 21:45:49 +0400 Subject: [PATCH 7/9] Address review: factor uniform magnitude, cover edge-case branches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Responds to @thierry-martinez's follow-up review on #524: - Factor a magnitude shared by every amplitude in the ket-notation output, so a state prints as `√2/2(|00⟩ + |11⟩)` rather than `√2/2|00⟩ + √2/2|11⟩` (signs are kept inside the parentheses). Compound (Cartesian) and non-uniform amplitudes fall back to the previous term-by-term rendering. Updated the affected tests and the docstring / doctest examples. - Cover the branches flagged as untested: the tiny-real -> "0" path (square-free decomposition and real-renderer zero branches), the decimal imaginary fallback, the exponential form with an unrecognized radius, and the empty-statevector "0". - Drop the redundant `x == 0` guard in `_recognize_sqrt`: the general path already returns `(0, 1, 1)` for `x == 0`, so the early return was dead code. Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 50 ++++++++++++++++++++++++++++++++------ graphix/sim/statevec.py | 2 +- tests/test_pretty_print.py | 35 ++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 072051e6f..4a00c56b5 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -509,8 +509,6 @@ def _recognize_sqrt(x: float, max_denominator: int, atol: float, rtol: float) -> squarefree integer, encoding ``x = signed_num * sqrt(inner) / den``. Returns ``None`` when ``x`` is not recognized as such a value. """ - if x == 0: - return 0, 1, 1 square = Fraction(x * x).limit_denominator(max_denominator) if not math.isclose(x * x, float(square), rel_tol=rtol, abs_tol=atol): return None @@ -749,6 +747,37 @@ def _needs_parentheses(coefficient: str) -> bool: return " + " in coefficient or " - " in coefficient +def _factor_uniform_magnitude(coefficients: Sequence[str], kets: Sequence[str]) -> str | None: + """Factor a magnitude shared by every amplitude, e.g. ``√2/2(|00⟩ + |11⟩)``. + + Returns the factored string when all amplitudes have the same non-unit magnitude + (up to sign), or ``None`` when no such common factor exists (in which case the + caller falls back to the term-by-term rendering). + """ + if len(coefficients) < 2: + return None + magnitudes: list[str] = [] + signs: list[str] = [] + for coefficient in coefficients: + # A compound (Cartesian) coefficient has no single magnitude to factor out. + if _needs_parentheses(coefficient): + return None + if coefficient.startswith("-"): + signs.append("-") + magnitudes.append(coefficient[1:]) + else: + signs.append("+") + magnitudes.append(coefficient) + common = magnitudes[0] + # Nothing to factor when the magnitude is the unit coefficient, or it is not shared. + if common == "1" or any(magnitude != common for magnitude in magnitudes): + return None + body = kets[0] if signs[0] == "+" else f"-{kets[0]}" + for sign, ket in zip(signs[1:], kets[1:], strict=True): + body += f" {sign} {ket}" + return f"{common}({body})" + + def statevec_to_str( statevec: Statevec, output: OutputFormat, @@ -788,17 +817,22 @@ def statevec_to_str( Returns ------- str - The formatted statevector, e.g. ``√2/2|00⟩ + √2/2|01⟩``. + The formatted statevector, e.g. ``√2/2(|00⟩ + |01⟩)``. """ amplitudes = statevec.to_dict(encoding, rtol=rtol, atol=atol) if not amplitudes: return "0" + coefficients = [ + complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) + for amplitude in amplitudes.values() + ] + kets = [_ket_str(ket, output) for ket in amplitudes] + # When every amplitude shares the same magnitude, factor it out, e.g. ``√2/2(|00⟩ + |11⟩)``. + factored = _factor_uniform_magnitude(coefficients, kets) + if factored is not None: + return factored result = "" - for index, (ket, amplitude) in enumerate(amplitudes.items()): - coefficient = complex_to_str( - amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision - ) - ket_str = _ket_str(ket, output) + for index, (coefficient, ket_str) in enumerate(zip(coefficients, kets, strict=True)): if coefficient == "1": term = ket_str elif coefficient == "-1": diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index cdfafeca5..dfde0dfcc 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -567,7 +567,7 @@ def draw( >>> circuit.h(0) >>> circuit.cz(0, 1) >>> print(circuit.simulate_statevector().statevec.draw()) - √2/2|00⟩ + √2/2|01⟩ + √2/2(|00⟩ + |01⟩) """ return statevec_to_str( self, diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index d44f45479..5fd9a2ae3 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,5 +1,6 @@ from __future__ import annotations +import cmath import math from typing import TYPE_CHECKING @@ -281,10 +282,11 @@ def test_complex_to_str_values(value: object, expected: Mapping[OutputFormat, st def test_statevec_draw() -> None: bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) - assert bell.draw(OutputFormat.Unicode) == "√2/2|00⟩ + √2/2|11⟩" - assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2|00> + sqrt(2)/2|11>" + # A magnitude shared by every amplitude is factored out. + assert bell.draw(OutputFormat.Unicode) == "√2/2(|00⟩ + |11⟩)" + assert bell.draw(OutputFormat.ASCII) == "sqrt(2)/2(|00> + |11>)" # LaTeX uses the \ket{...} macro for the basis kets. - assert bell.draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}\ket{00} + \frac{\sqrt{2}}{2}\ket{11}" + assert bell.draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}(\ket{00} + \ket{11})" def test_statevec_draw_single_basis_state() -> None: @@ -303,9 +305,9 @@ def test_density_matrix_draw() -> None: def test_statevec_draw_negative_and_parenthesized() -> None: - # Negative amplitudes use a `-` separator between terms. + # The shared 1/2 magnitude is factored out, with the signs kept inside the parentheses. neg = Statevec([0.5, -0.5, 0.5, 0.5]) - assert neg.draw(OutputFormat.Unicode) == "1/2|00⟩ - 1/2|01⟩ + 1/2|10⟩ + 1/2|11⟩" + assert neg.draw(OutputFormat.Unicode) == "1/2(|00⟩ - |01⟩ + |10⟩ + |11⟩)" # A compound (cartesian) amplitude is parenthesized before the ket. Build from a # numpy array so the amplitudes are ``numpy.complex128`` (Python's ``complex`` only # gained ``__complex__`` in 3.11, so a bare ``complex`` is rejected on 3.10). @@ -336,3 +338,26 @@ def test_density_matrix_draw_rtol() -> None: dm = DensityMatrix(data=[BasicStates.PLUS]) assert dm.draw(OutputFormat.Unicode) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" assert dm.draw(OutputFormat.Unicode, rtol=1e-4) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + + +def test_complex_to_str_edge_cases() -> None: + # A tiny non-zero real collapses to "0" (zero branches of the square-free + # decomposition and the real renderer). + assert complex_to_str(1e-7, OutputFormat.Unicode) == "0" + # A purely imaginary value with no nice form falls back to a decimal imaginary part. + assert complex_to_str(0.234567j, OutputFormat.Unicode) == "0.2346i" + # A recognized phase with an unrecognized radius (π·e^(iπ/4)) cannot use the exponential + # or Cartesian forms and falls back to a decimal. + assert complex_to_str(math.pi * cmath.exp(1j * math.pi / 4), OutputFormat.Unicode) == "2.221+2.221i" + + +def test_statevec_draw_all_amplitudes_dropped() -> None: + # A tolerance large enough to drop every amplitude prints the statevector as "0". + bell = Statevec([2**-0.5, 0, 0, 2**-0.5]) + assert bell.draw(OutputFormat.Unicode, atol=1.0) == "0" + + +def test_statevec_draw_non_uniform_not_factored() -> None: + # Amplitudes with different magnitudes are not factored; each term keeps its coefficient. + state = Statevec([math.sqrt(3) / 2, 0.5]) + assert state.draw(OutputFormat.Unicode) == "√3/2|0⟩ + 1/2|1⟩" From 78a48e7b3d72aca6a18d140df15163b496c143ed Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Sat, 13 Jun 2026 21:07:27 +0400 Subject: [PATCH 8/9] Cover the term-by-term fallback negative separator (codecov patch miss) The uniform-magnitude factoring moved the negative-amplitude case off the term-by-term loop, leaving `result += f" - {term[1:]}"` (pretty_print.py:847) uncovered. Give `test_statevec_draw_non_uniform_not_factored` a negative second amplitude so it flows through the fallback and exercises the ` - ` separator. Co-Authored-By: Claude Opus 4.8 --- tests/test_pretty_print.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 5fd9a2ae3..0a1f083dc 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -359,5 +359,7 @@ def test_statevec_draw_all_amplitudes_dropped() -> None: def test_statevec_draw_non_uniform_not_factored() -> None: # Amplitudes with different magnitudes are not factored; each term keeps its coefficient. - state = Statevec([math.sqrt(3) / 2, 0.5]) - assert state.draw(OutputFormat.Unicode) == "√3/2|0⟩ + 1/2|1⟩" + # The negative second amplitude also exercises the ` - ` separator in the term-by-term + # fallback (the factoring path handles the uniform-magnitude negative case separately). + state = Statevec([math.sqrt(3) / 2, -0.5]) + assert state.draw(OutputFormat.Unicode) == "√3/2|0⟩ - 1/2|1⟩" From 0faf58749a3edbc7c9a82a26d77e18b44b0c84b9 Mon Sep 17 00:00:00 2001 From: Vinny010 Date: Mon, 15 Jun 2026 22:01:53 +0400 Subject: [PATCH 9/9] Address @pranav97nair review: phase-aware magnitude factoring + max_denominator test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Factor the shared modulus even when components differ by a relative phase, so the statevector draw renders e.g. `√2/2(|0⟩ + i|1⟩)` (PLUS_I) and `√2/2(|0⟩ - i|1⟩)` (MINUS_I), not just the all-equal-coefficient case. `_factor_uniform_magnitude` now works on the amplitudes: it factors a common modulus `r` and keeps each per-component phase `amplitude / r` inside the parentheses. Falls back to term-by-term when the moduli differ, the modulus is trivial (1), or the amplitudes are non-numeric (symbolic). - Add tests for the relative-phase factoring (PLUS / PLUS_I / MINUS_I, Unicode + LaTeX) and for the `max_denominator` parameter of `DensityMatrix.draw` / `Statevec.draw` (1/2 -> 0.5 and √2/2 -> 0.7071 when capped), plus direct edge-case coverage of the factoring helper (non-numeric, zero modulus, unit modulus). Co-Authored-By: Claude Opus 4.8 --- graphix/pretty_print.py | 94 +++++++++++++++++++++++++------------- tests/test_pretty_print.py | 62 ++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 33 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 4a00c56b5..1c414f2fe 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -747,35 +747,61 @@ def _needs_parentheses(coefficient: str) -> bool: return " + " in coefficient or " - " in coefficient -def _factor_uniform_magnitude(coefficients: Sequence[str], kets: Sequence[str]) -> str | None: - """Factor a magnitude shared by every amplitude, e.g. ``√2/2(|00⟩ + |11⟩)``. - - Returns the factored string when all amplitudes have the same non-unit magnitude - (up to sign), or ``None`` when no such common factor exists (in which case the - caller falls back to the term-by-term rendering). +def _factor_uniform_magnitude( + amplitudes: Sequence[object], + kets: Sequence[str], + output: OutputFormat, + *, + max_denominator: int, + atol: float, + rtol: float, + precision: int, +) -> str | None: + """Factor a modulus shared by every amplitude, e.g. ``√2/2(|0⟩ + i|1⟩)``. + + Amplitudes that share a common modulus (up to a relative phase) are written as + ``r(φ₀|k₀⟩ + φ₁|k₁⟩ + …)``, with ``r`` the shared modulus and each ``φ`` the + per-component phase ``amplitude / r`` (so a relative phase such as ``i`` is kept + inside the parentheses). Returns ``None`` when the moduli differ, the shared + modulus is trivial (``1``), or the amplitudes are not numeric (e.g. symbolic), + in which case the caller renders term-by-term. """ - if len(coefficients) < 2: + if len(amplitudes) < 2: return None - magnitudes: list[str] = [] - signs: list[str] = [] - for coefficient in coefficients: - # A compound (Cartesian) coefficient has no single magnitude to factor out. - if _needs_parentheses(coefficient): - return None - if coefficient.startswith("-"): - signs.append("-") - magnitudes.append(coefficient[1:]) - else: - signs.append("+") - magnitudes.append(coefficient) - common = magnitudes[0] - # Nothing to factor when the magnitude is the unit coefficient, or it is not shared. - if common == "1" or any(magnitude != common for magnitude in magnitudes): + try: + values = [complex(amplitude) for amplitude in amplitudes] # type: ignore[call-overload] + except (TypeError, ValueError): + # Non-numeric (e.g. symbolic) amplitudes have no modulus to factor. + return None + radius = abs(values[0]) + if radius == 0 or any(not math.isclose(abs(v), radius, rel_tol=rtol, abs_tol=atol) for v in values): return None - body = kets[0] if signs[0] == "+" else f"-{kets[0]}" - for sign, ket in zip(signs[1:], kets[1:], strict=True): - body += f" {sign} {ket}" - return f"{common}({body})" + radius_str = complex_to_str( + radius, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + # A unit modulus is not worth factoring (it would read as ``1(...)``). + if radius_str == "1": + return None + result = "" + for index, (value, ket) in enumerate(zip(values, kets, strict=True)): + phase = complex_to_str( + value / radius, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) + # A per-component phase is unit-modulus, so it is rendered as 1, -1, ±i, an + # exponential, or a decimal — never a parenthesizable sum. + if phase == "1": + term = ket + elif phase == "-1": + term = f"-{ket}" + else: + term = f"{phase}{ket}" + if index == 0: + result = term + elif term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return f"{radius_str}({result})" def statevec_to_str( @@ -822,15 +848,19 @@ def statevec_to_str( amplitudes = statevec.to_dict(encoding, rtol=rtol, atol=atol) if not amplitudes: return "0" - coefficients = [ - complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) - for amplitude in amplitudes.values() - ] + amps = list(amplitudes.values()) kets = [_ket_str(ket, output) for ket in amplitudes] - # When every amplitude shares the same magnitude, factor it out, e.g. ``√2/2(|00⟩ + |11⟩)``. - factored = _factor_uniform_magnitude(coefficients, kets) + # When every amplitude shares a modulus (up to a relative phase), factor it out, + # e.g. ``√2/2(|0⟩ + i|1⟩)``. + factored = _factor_uniform_magnitude( + amps, kets, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision + ) if factored is not None: return factored + coefficients = [ + complex_to_str(amplitude, output, max_denominator=max_denominator, atol=atol, rtol=rtol, precision=precision) + for amplitude in amps + ] result = "" for index, (coefficient, ket_str) in enumerate(zip(coefficients, kets, strict=True)): if coefficient == "1": diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 0a1f083dc..87b3febbc 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -16,7 +16,7 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str +from graphix.pretty_print import OutputFormat, _factor_uniform_magnitude, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit from graphix.sim.density_matrix import DensityMatrix from graphix.sim.statevec import Statevec @@ -340,6 +340,66 @@ def test_density_matrix_draw_rtol() -> None: assert dm.draw(OutputFormat.Unicode, rtol=1e-4) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" +def test_statevec_draw_factor_relative_phase() -> None: + # A modulus shared up to a relative phase is still factored, with the phase kept + # inside the parentheses. + assert Statevec(data=BasicStates.PLUS).draw(OutputFormat.Unicode) == "√2/2(|0⟩ + |1⟩)" + assert Statevec(data=BasicStates.PLUS_I).draw(OutputFormat.Unicode) == "√2/2(|0⟩ + i|1⟩)" + assert Statevec(data=BasicStates.MINUS_I).draw(OutputFormat.Unicode) == "√2/2(|0⟩ - i|1⟩)" + assert ( + Statevec(data=BasicStates.PLUS_I).draw(OutputFormat.LaTeX) == r"\frac{\sqrt{2}}{2}(\ket{0} + \mathrm{i}\ket{1})" + ) + + +def test_factor_uniform_magnitude_edge_cases() -> None: + kets = ["|0⟩", "|1⟩"] + # Non-numeric (e.g. symbolic) amplitudes have no modulus -> None (term-by-term fallback). + assert ( + _factor_uniform_magnitude( + ["alpha", "beta"], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # Zero modulus -> None. + assert ( + _factor_uniform_magnitude( + [0, 0], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # A unit modulus is not worth factoring -> None. + assert ( + _factor_uniform_magnitude( + [1, 1j], kets, OutputFormat.Unicode, max_denominator=1000, atol=1e-9, rtol=0.0, precision=4 + ) + is None + ) + # A non-nice relative phase still factors, with the phase as a decimal coefficient. + factored = _factor_uniform_magnitude( + [0.5, 0.5 * cmath.exp(1j * 0.3)], + kets, + OutputFormat.Unicode, + max_denominator=1000, + atol=1e-9, + rtol=0.0, + precision=4, + ) + assert factored is not None + assert factored.startswith("1/2(|0⟩ + ") + + +def test_draw_max_denominator() -> None: + # `max_denominator` caps the denominators the recognition will accept. + dm = DensityMatrix(data=[BasicStates.PLUS]) + assert dm.draw(OutputFormat.Unicode) == "[ 1/2 1/2 ]\n[ 1/2 1/2 ]" + # With max_denominator=1, 1/2 can no longer be recognized and falls back to a decimal. + assert dm.draw(OutputFormat.Unicode, max_denominator=1) == "[ 0.5 0.5 ]\n[ 0.5 0.5 ]" + # Same effect on the statevector draw (√2/2 -> 0.7071). + sv = Statevec([2**-0.5, 2**-0.5]) + assert sv.draw(OutputFormat.Unicode) == "√2/2(|0⟩ + |1⟩)" + assert sv.draw(OutputFormat.Unicode, max_denominator=1) == "0.7071(|0⟩ + |1⟩)" + + def test_complex_to_str_edge_cases() -> None: # A tiny non-zero real collapses to "0" (zero branches of the square-free # decomposition and the real renderer).