From dbc7900022c2090953d806cd3ebc5f733a127f44 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Wed, 3 Jun 2026 12:00:57 -0700 Subject: [PATCH 1/6] Add pretty printing for statevectors and density matrices First, implements function that pretty prints a complex number. Then, adds to statevecotr and density matrix. The density matrix is printed in matrix form. --- docs/source/visualization.rst | 6 + graphix/pretty_print.py | 424 +++++++++++++++++++++++++++++++++- graphix/sim/density_matrix.py | 75 +++++- graphix/sim/statevec.py | 33 +++ tests/test_density_matrix.py | 47 ++++ tests/test_pretty_print.py | 41 +++- tests/test_statevec.py | 25 ++ 7 files changed, 641 insertions(+), 10 deletions(-) diff --git a/docs/source/visualization.rst b/docs/source/visualization.rst index 489838cbd..a48628c6c 100644 --- a/docs/source/visualization.rst +++ b/docs/source/visualization.rst @@ -23,6 +23,12 @@ This modules provides functions to format patterns and flows. .. autofunction:: angle_to_str +.. autofunction:: complex_to_str + +.. autofunction:: statevec_to_str + +.. autofunction:: density_matrix_to_str + .. autofunction:: command_to_str .. autofunction:: pattern_to_str diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..8cea29c0a 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -2,30 +2,42 @@ from __future__ import annotations +import cmath import enum import math import string from enum import Enum from fractions import Fraction from math import pi -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, Literal, SupportsComplex, SupportsFloat # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never from graphix import command -from graphix.fundamentals import AbstractMeasurement, Axis, Plane, Sign, angle_to_rad, rad_to_angle +from graphix.fundamentals import ( + AbstractMeasurement, + Axis, + Plane, + Sign, + angle_to_rad, + rad_to_angle, +) from graphix.measurements import BlochMeasurement, PauliMeasurement from graphix.parameter import AffineExpression if TYPE_CHECKING: - from collections.abc import Container, Iterable, Mapping, Sequence + from collections.abc import Callable, Container, Iterable, Mapping, Sequence from collections.abc import Set as AbstractSet from graphix.command import Node from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import Statevec + +_ENCODING = Literal["LSB", "MSB"] class OutputFormat(Enum): @@ -37,7 +49,11 @@ class OutputFormat(Enum): def angle_to_str( - angle: Angle, output: OutputFormat, max_denominator: int = 1000, multiplication_sign: bool = False + angle: Angle, + output: OutputFormat, + max_denominator: int = 1000, + multiplication_sign: bool = False, + abs_tol: float = 0.0, ) -> str: r""" Return a string representation of an angle given in units of π. @@ -61,6 +77,8 @@ def angle_to_str( ``2×π`` in Unicode, ``2 \times \pi`` in LaTeX, and ``2*pi`` in ASCII. If ``False``, the multiplication sign is implicit: ``2π`` in Unicode, ``2\pi`` in LaTeX, ``2pi`` in ASCII. + abs_tol : float, optional + Absolute tolerance passed to :func:`math.isclose` (default: ``0.0``). Returns ------- @@ -69,7 +87,7 @@ def angle_to_str( """ frac = Fraction(angle).limit_denominator(max_denominator) - if not math.isclose(angle, float(frac)): + if not math.isclose(angle, float(frac), abs_tol=abs_tol): rad = angle_to_rad(angle) return f"{rad}" @@ -110,6 +128,391 @@ def mkfrac(num: str, den: str) -> str: return f"{sign}{mkfrac(num_str, den_str)}" +_MAX_RADICAND = 10 + + +def _format_helpers( + output: OutputFormat, +) -> tuple[Callable[[str, str], str], Callable[[int], str]]: + if output == OutputFormat.LaTeX: + + def mkfrac(num: str, den: str) -> str: + return rf"\frac{{{num}}}{{{den}}}" + + def sqrt(n: int) -> str: + return rf"\sqrt{{{n}}}" + + elif output == OutputFormat.Unicode: + + def mkfrac(num: str, den: str) -> str: + return f"{num}/{den}" + + def sqrt(n: int) -> str: + return f"√{n}" + + else: + + def mkfrac(num: str, den: str) -> str: + return f"{num}/{den}" + + def sqrt(n: int) -> str: + return f"sqrt({n})" + + return mkfrac, sqrt + + +def _real_scalar_to_str( + x: float, + output: OutputFormat, + *, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, + max_radicand: int = _MAX_RADICAND, +) -> str | None: + mkfrac, sqrt = _format_helpers(output) + + frac = Fraction(x).limit_denominator(max_denominator) + if math.isclose(x, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + num, den = frac.numerator, frac.denominator + sign = "-" if num < 0 else "" + num = abs(num) + if den == 1: + return f"{sign}{num}" + return f"{sign}{mkfrac(str(num), str(den))}" + + for n in range(1, max_radicand + 1): + root = math.sqrt(n) + for d in range(1, max_denominator + 1): + val = root / d + if math.isclose(x, val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + return num_str if d == 1 else mkfrac(num_str, str(d)) + if math.isclose(x, -val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + formatted = num_str if d == 1 else mkfrac(num_str, str(d)) + return f"-{formatted}" + + return None + + +def _scalar_or_decimal( + x: float, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + result = _real_scalar_to_str(x, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + if result: + return result + return f"{x:g}" + + +def _imag_unit_str(coeff: str) -> str: + return "i" if coeff == "1" else f"{coeff} i" + + +def _exp_i_to_str(angle_str: str, output: OutputFormat) -> str: + match output: + case OutputFormat.LaTeX: + return rf"e^{{i{angle_str}}}" + case OutputFormat.Unicode: + return f"e^(i{angle_str})" + case OutputFormat.ASCII: + return f"e^(i*{angle_str})" + case _: + assert_never(output) + + +def _cartesian_to_str( + z: complex, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + real_zero = math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) + imag_zero = math.isclose(z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) + + if imag_zero: + return _scalar_or_decimal( + z.real, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + + imag_coeff = _scalar_or_decimal( + abs(z.imag), + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + imag_part = _imag_unit_str(imag_coeff) + + if real_zero: + return f"-{imag_part}" if z.imag < 0 else imag_part + + real_str = _scalar_or_decimal( + z.real, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + if z.imag >= 0: + return f"{real_str} + {imag_part}" + return f"{real_str} - {imag_part}" + + +def complex_to_str( + z: complex | SupportsComplex, + output: OutputFormat, + *, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a complex number. + + Common values are rendered symbolically: + + - rational reals such as ``0.25`` as ``1/4``, + - radical rationals such as ``0.70710678`` as ``√2/2``, + - unit-modulus values such as ``0.5 + 0.8660254j`` as ``e^(iπ/3)``. + + Parameters + ---------- + z : complex + The complex number to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + max_denominator : int, optional + Maximum denominator for detecting simple fractions and radical forms (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`math.isclose` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`math.isclose` (default: ``1e-8``). + + Returns + ------- + str + The formatted complex number. + """ + z = complex(z) + + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "0" + + if math.isclose(abs(z), 1.0, rel_tol=rel_tol, abs_tol=abs_tol): + if math.isclose(z.real, 1.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "1" + if math.isclose(z.real, -1.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "-1" + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 1.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "i" + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, -1.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "-i" + + angle = cmath.phase(z) / pi + frac = Fraction(angle).limit_denominator(max_denominator) + if math.isclose(angle, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + angle_str = angle_to_str(angle, output, max_denominator=max_denominator, abs_tol=abs_tol) + return _exp_i_to_str(angle_str, output) + + return _cartesian_to_str(z, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + + +def _ket_to_str(bits: str, output: OutputFormat) -> str: + match output: + case OutputFormat.LaTeX: + return rf"\ket{{{bits}}}" + case OutputFormat.Unicode: + return f"|{bits}⟩" + case OutputFormat.ASCII: + return f"|{bits}>" + case _: + assert_never(output) + + +def _format_statevec_term( + amp: complex, + ket: str, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> tuple[str, str]: + coeff = complex_to_str(amp, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + ket_str = _ket_to_str(ket, output) + if coeff == "1": + return "+", ket_str + if coeff == "-1": + return "-", ket_str + if coeff.startswith("-"): + return "-", f"{coeff[1:]}{ket_str}" + return "+", f"{coeff}{ket_str}" + + +def _join_statevec_terms(terms: list[tuple[str, str]]) -> str: + if not terms: + return "0" + sign, body = terms[0] + result = f"-{body}" if sign == "-" else body + for sign, body in terms[1:]: + result += f" - {body}" if sign == "-" else f" + {body}" + return result + + +def statevec_to_str( + statevec: Statevec, + output: OutputFormat, + encoding: _ENCODING = "MSB", + *, + rtol: float = 0.0, + atol: float = 1e-8, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a statevector in ket notation. + + Uses :meth:`graphix.sim.statevec.Statevec.to_dict` to obtain non-zero amplitudes, + and formats each amplitude with :func:`complex_to_str`. + + Parameters + ---------- + statevec : Statevec + The statevector to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets. See :meth:`graphix.sim.statevec.Statevec.to_dict`. + rtol : float, default=0.0 + Relative tolerance for filtering zero amplitudes, passed to :meth:`Statevec.to_dict`. + atol : float, default=1e-8 + Absolute tolerance for filtering zero amplitudes, passed to :meth:`Statevec.to_dict`. + max_denominator : int, optional + Maximum denominator for detecting simple amplitudes (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`complex_to_str` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`complex_to_str` (default: ``1e-8``). + + Returns + ------- + str + The formatted statevector as a sum of ket terms. + """ + amplitudes = statevec.to_dict(encoding=encoding, rtol=rtol, atol=atol) + terms = [ + _format_statevec_term( + complex(amp), + ket, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + for ket, amp in amplitudes.items() + ] + result = _join_statevec_terms(terms) + if output == OutputFormat.LaTeX: + return f"\\({result}\\)" + return result + + +def density_matrix_to_str( + density_matrix: DensityMatrix, + output: OutputFormat, + *, + rtol: float = 0.0, + atol: float = 1e-8, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a density matrix. + + Formats each matrix element with :func:`complex_to_str`. + + Parameters + ---------- + density_matrix : DensityMatrix + The density matrix to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for displaying negligible elements as ``0``. + atol : float, default=1e-8 + Absolute tolerance for displaying negligible elements as ``0``. + max_denominator : int, optional + Maximum denominator for detecting simple matrix elements (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`complex_to_str` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`complex_to_str` (default: ``1e-8``). + + Returns + ------- + str + The formatted density matrix. + """ + rho = density_matrix.rho + nrows, ncols = rho.shape + + def format_cell(value: complex) -> str: + if math.isclose(abs(value), 0.0, rel_tol=rtol, abs_tol=atol): + return "0" + return complex_to_str( + value, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + + cells = [[format_cell(complex(rho[i, j])) for j in range(ncols)] for i in range(nrows)] + col_widths = [max(len(cells[i][j]) for i in range(nrows)) for j in range(ncols)] + + if output == OutputFormat.LaTeX: + rows = [" & ".join(cell.rjust(col_widths[j]) for j, cell in enumerate(row)) for row in cells] + body = r" \\ ".join(rows) + return rf"\(\begin{{pmatrix}} {body} \end{{pmatrix}}\)" + + lines: list[str] = [] + for i, row in enumerate(cells): + inner = ", ".join(cell.rjust(col_widths[j]) for j, cell in enumerate(row)) + if nrows == 1: + lines.append(f"[[{inner}]]") + elif i == 0: + lines.append(f"[[{inner}],") + elif i == nrows - 1: + lines.append(f" [{inner}]]") + else: + lines.append(f" [{inner}],") + return "\n".join(lines) + + def domain_to_str(domain: set[Node]) -> str: """Return the string representation of a domain.""" return f"{{{','.join(str(node) for node in domain)}}}" @@ -299,7 +702,10 @@ def set_to_str(objects: Iterable[object], output: OutputFormat) -> str: def correction_function_to_str( - correction_function: Mapping[int, AbstractSet[int]], cf_name: str, output: OutputFormat, multiline: bool = False + correction_function: Mapping[int, AbstractSet[int]], + cf_name: str, + output: OutputFormat, + multiline: bool = False, ) -> str: """Convert a correction function mapping to a formatted string representation. @@ -411,7 +817,11 @@ def flow_to_str(flow: PauliFlow[AbstractMeasurement], output: OutputFormat, mult ) -def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputFormat, multiline: bool = False) -> str: +def xzcorr_to_str( + xzcorr: XZCorrections[AbstractMeasurement], + output: OutputFormat, + multiline: bool = False, +) -> str: """Convert an XZCorrections object to a formatted string representation. Parameters diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..53195edc9 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,18 +19,21 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot -from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec +from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec, _format_encoding from graphix.states import BasicStates, State if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from typing import SupportsComplex, SupportsFloat + from typing import Literal, SupportsComplex, SupportsFloat from graphix.noise_models.noise_model import Noise from graphix.parameter import ExpressionOrSupportsFloat, Parameter from graphix.sim.data import Data + _ENCODING = Literal["LSB", "MSB"] + class DensityMatrix(DenseState): """DensityMatrix object.""" @@ -354,6 +357,74 @@ def flatten(self) -> Matrix: """Return flattened density matrix.""" return self.rho.flatten() + def to_dict( + self, + encoding: _ENCODING = "MSB", + *, + rtol: float = 0.0, + atol: float = 1e-8, + ) -> dict[tuple[str, str], np.complex128]: + r"""Convert the density matrix to dictionary form. + + This dictionary representation uses ket-bra notation where the dictionary keys + are ``(ket, bra)`` pairs of qubit strings for the basis vectors and values are + the corresponding complex matrix elements. Elements below a certain threshold + are filtered out. + + Parameters + ---------- + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets and bras. See :meth:`graphix.sim.statevec.Statevec.to_dict`. + rtol : float, default=0.0 + Relative tolerance used when deciding whether an element should be treated as zero. + atol : float, default=1e-8 + Absolute tolerance used when deciding whether an element should be treated as zero. + + Returns + ------- + dict[tuple[str, str], complex] + The density matrix in dictionary form. + """ + result: dict[tuple[str, str], np.complex128] = {} + for i in range(2**self.nqubit): + for j in range(2**self.nqubit): + element = self.rho[i, j] + if np.isclose(np.abs(element), 0, rtol=rtol, atol=atol): + continue + ket = _format_encoding(self.nqubit, i, encoding) + bra = _format_encoding(self.nqubit, j, encoding) + result[(ket, bra)] = element + return result + + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + rtol: float = 0.0, + atol: float = 1e-8, + ) -> str: + r"""Return a pretty-printed string representation of the density matrix. + + Parameters + ---------- + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for displaying negligible elements as ``0``. + atol : float, default=1e-8 + Absolute tolerance for displaying negligible elements as ``0``. + + Returns + ------- + str + The formatted density matrix. + + See Also + -------- + :func:`graphix.pretty_print.density_matrix_to_str` + """ + return density_matrix_to_str(self, output, rtol=rtol, atol=atol) + def apply_channel(self, channel: KrausChannel, qargs: Sequence[int]) -> None: """Apply a channel to a density matrix. diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..c67860de2 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -486,6 +487,38 @@ def to_dict( """ return self._to_dict_map(lambda x: x, encoding, rtol=rtol, atol=atol) + def draw( + self, + encoding: _ENCODING = "MSB", + output: OutputFormat = OutputFormat.Unicode, + *, + rtol: float = 0.0, + atol: float = 1e-8, + ) -> str: + r"""Return a pretty-printed string representation of the statevector in ket notation. + + Parameters + ---------- + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets. See :meth:`to_dict` for additional information. + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for filtering zero amplitudes. See :meth:`to_dict`. + atol : float, default=1e-8 + Absolute tolerance for filtering zero amplitudes. See :meth:`to_dict`. + + Returns + ------- + str + The formatted statevector as a sum of ket terms. + + See Also + -------- + :func:`graphix.pretty_print.statevec_to_str` + """ + return statevec_to_str(self, output, encoding, rtol=rtol, atol=atol) + def to_prob_dict( self, encoding: _ENCODING = "MSB", *, rtol: float = 0.0, atol: float = 1e-8 ) -> dict[str, np.object_ | np.float64]: diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 7c67909c5..9a5cedef9 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -15,6 +15,7 @@ from graphix.channels import KrausChannel, dephasing_channel, depolarising_channel from graphix.fundamentals import ANGLE_PI, Plane from graphix.ops import Ops +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.density_matrix import DensityMatrix, DensityMatrixBackend from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.simulator import DefaultMeasureMethod @@ -929,3 +930,49 @@ def test_measure(self, outcome: Outcome) -> None: else np.kron(np.array([[1, 0], [0, 0]]), np.ones((2, 2)) / 2) ) assert np.allclose(backend.state.rho, expected_matrix) + + +@pytest.mark.parametrize( + ("output", "expected"), + [ + (OutputFormat.Unicode, "[[0, 0, 0, 0],\n [0, 1, 0, 0],\n [0, 0, 0, 0],\n [0, 0, 0, 0]]"), + (OutputFormat.ASCII, "[[0, 0, 0, 0],\n [0, 1, 0, 0],\n [0, 0, 0, 0],\n [0, 0, 0, 0]]"), + ( + OutputFormat.LaTeX, + r"\(\begin{pmatrix} 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{pmatrix}\)", + ), + ], +) +def test_density_matrix_draw_pure_state(output: OutputFormat, expected: str) -> None: + dm = DensityMatrix(data=[BasicStates.ZERO, BasicStates.ONE]) + assert dm.rho.shape == (4, 4) + assert dm.draw(output=output) == expected + assert density_matrix_to_str(dm, output) == expected + + +def test_density_matrix_draw_superposition() -> None: + dm = DensityMatrix(data=BasicStates.PLUS) + assert dm.draw(output=OutputFormat.Unicode) == "[[1/2, 1/2],\n [1/2, 1/2]]" + assert density_matrix_to_str(dm, OutputFormat.Unicode) == "[[1/2, 1/2],\n [1/2, 1/2]]" + assert dm.draw(output=OutputFormat.LaTeX) == ( + r"\(\begin{pmatrix} \frac{1}{2} & \frac{1}{2} \\ \frac{1}{2} & \frac{1}{2} \end{pmatrix}\)" + ) + + +def test_density_matrix_draw_4x4() -> None: + sv = Statevec(data=np.array([1, 0, 0, 1], dtype=np.complex128) / np.sqrt(2)) + dm = DensityMatrix(data=sv) + assert dm.rho.shape == (4, 4) + assert ( + dm.draw(output=OutputFormat.Unicode) + == "[[1/2, 0, 0, 1/2],\n [ 0, 0, 0, 0],\n [ 0, 0, 0, 0],\n [1/2, 0, 0, 1/2]]" + ) + assert density_matrix_to_str(dm, OutputFormat.Unicode) == ( + "[[1/2, 0, 0, 1/2],\n [ 0, 0, 0, 0],\n [ 0, 0, 0, 0],\n [1/2, 0, 0, 1/2]]" + ) + assert dm.draw(output=OutputFormat.LaTeX) == ( + r"\(\begin{pmatrix} \frac{1}{2} & 0 & 0 & \frac{1}{2} \\ " + r" 0 & 0 & 0 & 0 \\ " + r" 0 & 0 & 0 & 0 \\ " + r"\frac{1}{2} & 0 & 0 & \frac{1}{2} \end{pmatrix}\)" + ) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..3221b7353 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,5 +1,7 @@ from __future__ import annotations +import cmath +import math from typing import TYPE_CHECKING import networkx as nx @@ -13,7 +15,7 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit from graphix.transpiler import Circuit @@ -202,3 +204,40 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +@pytest.mark.parametrize( + ("z", "output", "expected"), + [ + (0.25, OutputFormat.ASCII, "1/4"), + (0.25, OutputFormat.Unicode, "1/4"), + (0.25, OutputFormat.LaTeX, r"\frac{1}{4}"), + (0.25 + 0j, OutputFormat.Unicode, "1/4"), + (0.70710678, OutputFormat.ASCII, "sqrt(2)/2"), + (0.70710678, OutputFormat.Unicode, "√2/2"), + (0.70710678, OutputFormat.LaTeX, r"\frac{\sqrt{2}}{2}"), + (math.sqrt(3) / 2, OutputFormat.Unicode, "√3/2"), + (math.sqrt(3) / 2, OutputFormat.LaTeX, r"\frac{\sqrt{3}}{2}"), + (0.5 + 0.8660254j, OutputFormat.ASCII, "e^(i*pi/3)"), + (0.5 + 0.8660254j, OutputFormat.Unicode, "e^(iπ/3)"), + (0.5 + 0.8660254j, OutputFormat.LaTeX, r"e^{i\frac{\pi}{3}}"), + (cmath.exp(1j * math.pi / 3), OutputFormat.Unicode, "e^(iπ/3)"), + (0, OutputFormat.Unicode, "0"), + (1, OutputFormat.Unicode, "1"), + (-1, OutputFormat.Unicode, "-1"), + (1j, OutputFormat.Unicode, "i"), + (-1j, OutputFormat.Unicode, "-i"), + (0.25 + 0.25j, OutputFormat.ASCII, "1/4 + 1/4 i"), + (0.25 + 0.25j, OutputFormat.Unicode, "1/4 + 1/4 i"), + (0.25 + 0.25j, OutputFormat.LaTeX, r"\frac{1}{4} + \frac{1}{4} i"), + ], +) +def test_complex_to_str(z: complex, output: OutputFormat, expected: str) -> None: + assert complex_to_str(z, output) == expected + + +def test_complex_to_str_fallback() -> None: + z = 0.123 + 0.456j + assert complex_to_str(z, OutputFormat.ASCII, max_denominator=1) == "0.123 + 0.456 i" + assert complex_to_str(z, OutputFormat.Unicode, max_denominator=1) == "0.123 + 0.456 i" + assert complex_to_str(z, OutputFormat.LaTeX, max_denominator=1) == "0.123 + 0.456 i" diff --git a/tests/test_statevec.py b/tests/test_statevec.py index e7ce8d925..b0df79560 100644 --- a/tests/test_statevec.py +++ b/tests/test_statevec.py @@ -8,6 +8,7 @@ from graphix.fundamentals import ANGLE_PI, Plane from graphix.pattern import Pattern +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.statevec import Statevec, _norm_numeric from graphix.states import BasicStates, PlanarState @@ -239,6 +240,30 @@ def test_to_prob_dict(self, encoding: _ENCODING, dict_ref: Mapping[str, float]) assert np.isclose(0, amp2.imag) +@pytest.mark.parametrize( + ("encoding", "output", "expected"), + [ + ("MSB", OutputFormat.Unicode, "|01⟩"), + ("MSB", OutputFormat.ASCII, "|01>"), + ("MSB", OutputFormat.LaTeX, r"\(\ket{01}\)"), + ("LSB", OutputFormat.Unicode, "|10⟩"), + ], +) +def test_statevec_draw_single_ket(encoding: _ENCODING, output: OutputFormat, expected: str) -> None: + sv = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + assert sv.draw(encoding=encoding, output=output) == expected + assert statevec_to_str(sv, output, encoding=encoding) == expected + + +def test_statevec_draw_superposition() -> None: + sv = Statevec(data=[BasicStates.ZERO, BasicStates.PLUS, BasicStates.MINUS]) + assert sv.draw(encoding="MSB", output=OutputFormat.Unicode) == "1/2|000⟩ - 1/2|001⟩ + 1/2|010⟩ - 1/2|011⟩" + assert statevec_to_str(sv, OutputFormat.Unicode, encoding="MSB") == ("1/2|000⟩ - 1/2|001⟩ + 1/2|010⟩ - 1/2|011⟩") + assert sv.draw(encoding="MSB", output=OutputFormat.LaTeX) == ( + r"\(\frac{1}{2}\ket{000} - \frac{1}{2}\ket{001} + \frac{1}{2}\ket{010} - \frac{1}{2}\ket{011}\)" + ) + + def test_normalize() -> None: statevec = Statevec(nqubit=1, data=BasicStates.PLUS) statevec.remove_qubit(0) From 82bef3d39e28149e894e163c43bdd5fe64b0b0f1 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Wed, 3 Jun 2026 12:34:47 -0700 Subject: [PATCH 2/6] remove dm to dict --- graphix/sim/density_matrix.py | 45 ++--------------------------------- 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index 53195edc9..1471f20d6 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -21,19 +21,17 @@ from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot -from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec, _format_encoding +from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from typing import Literal, SupportsComplex, SupportsFloat + from typing import SupportsComplex, SupportsFloat from graphix.noise_models.noise_model import Noise from graphix.parameter import ExpressionOrSupportsFloat, Parameter from graphix.sim.data import Data - _ENCODING = Literal["LSB", "MSB"] - class DensityMatrix(DenseState): """DensityMatrix object.""" @@ -357,45 +355,6 @@ def flatten(self) -> Matrix: """Return flattened density matrix.""" return self.rho.flatten() - def to_dict( - self, - encoding: _ENCODING = "MSB", - *, - rtol: float = 0.0, - atol: float = 1e-8, - ) -> dict[tuple[str, str], np.complex128]: - r"""Convert the density matrix to dictionary form. - - This dictionary representation uses ket-bra notation where the dictionary keys - are ``(ket, bra)`` pairs of qubit strings for the basis vectors and values are - the corresponding complex matrix elements. Elements below a certain threshold - are filtered out. - - Parameters - ---------- - encoding : Literal["LSB", "MSB"], default="MSB" - Encoding for the basis kets and bras. See :meth:`graphix.sim.statevec.Statevec.to_dict`. - rtol : float, default=0.0 - Relative tolerance used when deciding whether an element should be treated as zero. - atol : float, default=1e-8 - Absolute tolerance used when deciding whether an element should be treated as zero. - - Returns - ------- - dict[tuple[str, str], complex] - The density matrix in dictionary form. - """ - result: dict[tuple[str, str], np.complex128] = {} - for i in range(2**self.nqubit): - for j in range(2**self.nqubit): - element = self.rho[i, j] - if np.isclose(np.abs(element), 0, rtol=rtol, atol=atol): - continue - ket = _format_encoding(self.nqubit, i, encoding) - bra = _format_encoding(self.nqubit, j, encoding) - result[(ket, bra)] = element - return result - def draw( self, output: OutputFormat = OutputFormat.Unicode, From e373cfb6501f7f5972af8498c4b0aa0c88cc1af3 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Thu, 4 Jun 2026 18:52:49 -0700 Subject: [PATCH 3/6] code coverage, mathrm, imaginary formatting --- graphix/pretty_print.py | 93 +++++++++++++++++++++++++----------- tests/test_density_matrix.py | 7 +++ tests/test_pretty_print.py | 56 ++++++++++++++++++---- tests/test_statevec.py | 13 +++++ 4 files changed, 133 insertions(+), 36 deletions(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 8cea29c0a..41b1dd7eb 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -48,6 +48,10 @@ class OutputFormat(Enum): Unicode = enum.auto() +def _validate_output_format(output: OutputFormat) -> None: + assert output in (OutputFormat.ASCII, OutputFormat.LaTeX, OutputFormat.Unicode) + + def angle_to_str( angle: Angle, output: OutputFormat, @@ -85,12 +89,11 @@ def angle_to_str( str The formatted angle. """ + _validate_output_format(output) frac = Fraction(angle).limit_denominator(max_denominator) if not math.isclose(angle, float(frac), abs_tol=abs_tol): - rad = angle_to_rad(angle) - - return f"{rad}" + return f"{angle_to_rad(angle):.2f}" num, den = frac.numerator, frac.denominator sign = "-" if num < 0 else "" @@ -210,20 +213,56 @@ def _scalar_or_decimal( return f"{x:g}" -def _imag_unit_str(coeff: str) -> str: - return "i" if coeff == "1" else f"{coeff} i" +def _imag_scalar_to_str( + x: float, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + i = r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + x = abs(x) + + if math.isclose(x, 1.0, rel_tol=rel_tol, abs_tol=abs_tol): + return i + + mkfrac, sqrt = _format_helpers(output) + + frac = Fraction(x).limit_denominator(max_denominator) + if math.isclose(x, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + num, den = abs(frac.numerator), frac.denominator + if den == 1: + return f"{num}{i}" + if num == 1: + return f"{i}/{den}" if output != OutputFormat.LaTeX else rf"\frac{{{i}}}{{{den}}}" + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}{i}}}{{{den}}}" + return f"{num}{i}/{den}" + + for n in range(1, _MAX_RADICAND + 1): + root = math.sqrt(n) + for d in range(1, max_denominator + 1): + val = root / d + if math.isclose(x, val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + if d == 1: + return i if num_str == "1" else f"{num_str}{i}" + if num_str == "1": + return f"{i}/{d}" if output != OutputFormat.LaTeX else rf"\frac{{{i}}}{{{d}}}" + if output == OutputFormat.LaTeX: + return rf"\frac{{{num_str}{i}}}{{{d}}}" + return f"{num_str}{i}/{d}" + + return f"{x:g}{i}" def _exp_i_to_str(angle_str: str, output: OutputFormat) -> str: - match output: - case OutputFormat.LaTeX: - return rf"e^{{i{angle_str}}}" - case OutputFormat.Unicode: - return f"e^(i{angle_str})" - case OutputFormat.ASCII: - return f"e^(i*{angle_str})" - case _: - assert_never(output) + if output == OutputFormat.LaTeX: + return rf"\mathrm{{e}}^{{\mathrm{{i}}{angle_str}}}" + if output == OutputFormat.Unicode: + return f"e^(i{angle_str})" + return f"e^(i*{angle_str})" def _cartesian_to_str( @@ -246,14 +285,13 @@ def _cartesian_to_str( abs_tol=abs_tol, ) - imag_coeff = _scalar_or_decimal( - abs(z.imag), + imag_part = _imag_scalar_to_str( + z.imag, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol, ) - imag_part = _imag_unit_str(imag_coeff) if real_zero: return f"-{imag_part}" if z.imag < 0 else imag_part @@ -305,6 +343,7 @@ def complex_to_str( str The formatted complex number. """ + _validate_output_format(output) z = complex(z) if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( @@ -324,11 +363,11 @@ def complex_to_str( if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( z.imag, 1.0, rel_tol=rel_tol, abs_tol=abs_tol ): - return "i" + return r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( z.imag, -1.0, rel_tol=rel_tol, abs_tol=abs_tol ): - return "-i" + return r"-\mathrm{i}" if output == OutputFormat.LaTeX else "-i" angle = cmath.phase(z) / pi frac = Fraction(angle).limit_denominator(max_denominator) @@ -340,15 +379,11 @@ def complex_to_str( def _ket_to_str(bits: str, output: OutputFormat) -> str: - match output: - case OutputFormat.LaTeX: - return rf"\ket{{{bits}}}" - case OutputFormat.Unicode: - return f"|{bits}⟩" - case OutputFormat.ASCII: - return f"|{bits}>" - case _: - assert_never(output) + if output == OutputFormat.LaTeX: + return rf"\ket{{{bits}}}" + if output == OutputFormat.Unicode: + return f"|{bits}⟩" + return f"|{bits}>" def _format_statevec_term( @@ -422,6 +457,7 @@ def statevec_to_str( str The formatted statevector as a sum of ket terms. """ + _validate_output_format(output) amplitudes = statevec.to_dict(encoding=encoding, rtol=rtol, atol=atol) terms = [ _format_statevec_term( @@ -477,6 +513,7 @@ def density_matrix_to_str( str The formatted density matrix. """ + _validate_output_format(output) rho = density_matrix.rho nrows, ncols = rho.shape diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 9a5cedef9..ccf73179e 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -959,6 +959,13 @@ def test_density_matrix_draw_superposition() -> None: ) +def test_density_matrix_draw_single_row() -> None: + dm = DensityMatrix(nqubit=0, data=np.array([[1.0]], dtype=np.complex128)) + assert density_matrix_to_str(dm, OutputFormat.Unicode) == "[[1]]" + assert density_matrix_to_str(dm, OutputFormat.ASCII) == "[[1]]" + assert density_matrix_to_str(dm, OutputFormat.LaTeX) == r"\(\begin{pmatrix} 1 \end{pmatrix}\)" + + def test_density_matrix_draw_4x4() -> None: sv = Statevec(data=np.array([1, 0, 0, 1], dtype=np.complex128) / np.sqrt(2)) dm = DensityMatrix(data=sv) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 3221b7353..2a029039e 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -15,7 +15,7 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str +from graphix.pretty_print import OutputFormat, angle_to_str, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit from graphix.transpiler import Circuit @@ -216,20 +216,38 @@ def test_xzcorr_str() -> None: (0.70710678, OutputFormat.ASCII, "sqrt(2)/2"), (0.70710678, OutputFormat.Unicode, "√2/2"), (0.70710678, OutputFormat.LaTeX, r"\frac{\sqrt{2}}{2}"), + (-0.70710678, OutputFormat.ASCII, "-sqrt(2)/2"), + (-0.70710678, OutputFormat.Unicode, "-√2/2"), + (-0.70710678, OutputFormat.LaTeX, r"-\frac{\sqrt{2}}{2}"), (math.sqrt(3) / 2, OutputFormat.Unicode, "√3/2"), (math.sqrt(3) / 2, OutputFormat.LaTeX, r"\frac{\sqrt{3}}{2}"), + (-math.sqrt(3) / 2, OutputFormat.Unicode, "-√3/2"), + (-math.sqrt(2), OutputFormat.Unicode, "-√2"), + (-math.sqrt(2), OutputFormat.LaTeX, r"-\sqrt{2}"), (0.5 + 0.8660254j, OutputFormat.ASCII, "e^(i*pi/3)"), (0.5 + 0.8660254j, OutputFormat.Unicode, "e^(iπ/3)"), - (0.5 + 0.8660254j, OutputFormat.LaTeX, r"e^{i\frac{\pi}{3}}"), + (0.5 + 0.8660254j, OutputFormat.LaTeX, r"\mathrm{e}^{\mathrm{i}\frac{\pi}{3}}"), (cmath.exp(1j * math.pi / 3), OutputFormat.Unicode, "e^(iπ/3)"), (0, OutputFormat.Unicode, "0"), (1, OutputFormat.Unicode, "1"), (-1, OutputFormat.Unicode, "-1"), + (2, OutputFormat.Unicode, "2"), + (-3, OutputFormat.ASCII, "-3"), + (2, OutputFormat.LaTeX, "2"), + (2j, OutputFormat.Unicode, "2i"), + (-2j, OutputFormat.LaTeX, r"-2\mathrm{i}"), + (0.25j, OutputFormat.Unicode, "i/4"), + (-0.25j, OutputFormat.LaTeX, r"-\frac{\mathrm{i}}{4}"), (1j, OutputFormat.Unicode, "i"), + (1j, OutputFormat.LaTeX, r"\mathrm{i}"), (-1j, OutputFormat.Unicode, "-i"), - (0.25 + 0.25j, OutputFormat.ASCII, "1/4 + 1/4 i"), - (0.25 + 0.25j, OutputFormat.Unicode, "1/4 + 1/4 i"), - (0.25 + 0.25j, OutputFormat.LaTeX, r"\frac{1}{4} + \frac{1}{4} i"), + (-1j, OutputFormat.LaTeX, r"-\mathrm{i}"), + (0.25 + 0.25j, OutputFormat.ASCII, "1/4 + i/4"), + (0.25 + 0.25j, OutputFormat.Unicode, "1/4 + i/4"), + (0.25 + 0.25j, OutputFormat.LaTeX, r"\frac{1}{4} + \frac{\mathrm{i}}{4}"), + (0.25 - 0.25j, OutputFormat.Unicode, "1/4 - i/4"), + (2 - 3j, OutputFormat.ASCII, "2 - 3i"), + (2 - 3j, OutputFormat.LaTeX, r"2 - 3\mathrm{i}"), ], ) def test_complex_to_str(z: complex, output: OutputFormat, expected: str) -> None: @@ -238,6 +256,28 @@ def test_complex_to_str(z: complex, output: OutputFormat, expected: str) -> None def test_complex_to_str_fallback() -> None: z = 0.123 + 0.456j - assert complex_to_str(z, OutputFormat.ASCII, max_denominator=1) == "0.123 + 0.456 i" - assert complex_to_str(z, OutputFormat.Unicode, max_denominator=1) == "0.123 + 0.456 i" - assert complex_to_str(z, OutputFormat.LaTeX, max_denominator=1) == "0.123 + 0.456 i" + assert complex_to_str(z, OutputFormat.ASCII, max_denominator=1) == "0.123 + 0.456i" + assert complex_to_str(z, OutputFormat.Unicode, max_denominator=1) == "0.123 + 0.456i" + assert complex_to_str(z, OutputFormat.LaTeX, max_denominator=1) == "0.123 + 0.456\\mathrm{i}" + + +@pytest.mark.parametrize( + ("angle", "output", "expected"), + [ + (0.5, OutputFormat.Unicode, "π/2"), + (0.5, OutputFormat.ASCII, "pi/2"), + (0.5, OutputFormat.LaTeX, r"\frac{\pi}{2}"), + ], +) +def test_angle_to_str_fraction(angle: float, output: OutputFormat, expected: str) -> None: + assert angle_to_str(angle, output) == expected + + +@pytest.mark.parametrize("output", list(OutputFormat)) +def test_angle_to_str_radian_fallback(output: OutputFormat) -> None: + angle = 0.123456789 + assert angle_to_str(angle, output) == f"{angle * math.pi:.2f}" + + +def test_angle_to_str_radian_fallback_max_denominator() -> None: + assert angle_to_str(0.7, OutputFormat.Unicode, max_denominator=3) == f"{0.7 * math.pi:.2f}" diff --git a/tests/test_statevec.py b/tests/test_statevec.py index b0df79560..45468e26f 100644 --- a/tests/test_statevec.py +++ b/tests/test_statevec.py @@ -255,6 +255,19 @@ def test_statevec_draw_single_ket(encoding: _ENCODING, output: OutputFormat, exp assert statevec_to_str(sv, output, encoding=encoding) == expected +def test_statevec_draw_single_ket_negative_amplitude() -> None: + sv = Statevec(data=np.array([0, -1], dtype=np.complex128)) + assert statevec_to_str(sv, OutputFormat.Unicode) == "-|1⟩" + assert statevec_to_str(sv, OutputFormat.ASCII) == "-|1>" + assert statevec_to_str(sv, OutputFormat.LaTeX) == r"\(-\ket{1}\)" + + +def test_statevec_draw_empty() -> None: + sv = Statevec(data=BasicStates.PLUS) + assert statevec_to_str(sv, OutputFormat.Unicode, atol=1.0) == "0" + assert statevec_to_str(sv, OutputFormat.LaTeX, atol=1.0) == r"\(0\)" + + def test_statevec_draw_superposition() -> None: sv = Statevec(data=[BasicStates.ZERO, BasicStates.PLUS, BasicStates.MINUS]) assert sv.draw(encoding="MSB", output=OutputFormat.Unicode) == "1/2|000⟩ - 1/2|001⟩ + 1/2|010⟩ - 1/2|011⟩" From b4a54ff16ece86bf129493f05801b100cd6a5991 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Thu, 4 Jun 2026 19:02:57 -0700 Subject: [PATCH 4/6] coverage and lint --- graphix/pretty_print.py | 2 +- tests/test_pretty_print.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 41b1dd7eb..98fb8dbde 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -227,7 +227,7 @@ def _imag_scalar_to_str( if math.isclose(x, 1.0, rel_tol=rel_tol, abs_tol=abs_tol): return i - mkfrac, sqrt = _format_helpers(output) + _, sqrt = _format_helpers(output) frac = Fraction(x).limit_denominator(max_denominator) if math.isclose(x, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 2a029039e..54732823e 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -248,6 +248,11 @@ def test_xzcorr_str() -> None: (0.25 - 0.25j, OutputFormat.Unicode, "1/4 - i/4"), (2 - 3j, OutputFormat.ASCII, "2 - 3i"), (2 - 3j, OutputFormat.LaTeX, r"2 - 3\mathrm{i}"), + (2 + 1j, OutputFormat.Unicode, "2 + i"), + (0.75j, OutputFormat.LaTeX, r"\frac{3\mathrm{i}}{4}"), + (math.sqrt(2) / 2 * 1j, OutputFormat.Unicode, "√2i/2"), + (math.sqrt(2) / 2 * 1j, OutputFormat.LaTeX, r"\frac{\sqrt{2}\mathrm{i}}{2}"), + (2 + math.sqrt(2) / 2 * 1j, OutputFormat.Unicode, "2 + √2i/2"), ], ) def test_complex_to_str(z: complex, output: OutputFormat, expected: str) -> None: @@ -264,6 +269,13 @@ def test_complex_to_str_fallback() -> None: @pytest.mark.parametrize( ("angle", "output", "expected"), [ + (0, OutputFormat.Unicode, "0"), + (0, OutputFormat.ASCII, "0"), + (0, OutputFormat.LaTeX, "0"), + (2, OutputFormat.Unicode, "2π"), + (2, OutputFormat.ASCII, "2pi"), + (2, OutputFormat.LaTeX, r"2\pi"), + (-3, OutputFormat.Unicode, "-3π"), (0.5, OutputFormat.Unicode, "π/2"), (0.5, OutputFormat.ASCII, "pi/2"), (0.5, OutputFormat.LaTeX, r"\frac{\pi}{2}"), From 5b7ad1233611b583f0e3a771b80bd73efcec9388 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Sat, 6 Jun 2026 11:54:02 -0700 Subject: [PATCH 5/6] change assert to raise --- graphix/pretty_print.py | 3 ++- tests/test_pretty_print.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 98fb8dbde..2a1c41479 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -49,7 +49,8 @@ class OutputFormat(Enum): def _validate_output_format(output: OutputFormat) -> None: - assert output in (OutputFormat.ASCII, OutputFormat.LaTeX, OutputFormat.Unicode) + if output not in (OutputFormat.ASCII, OutputFormat.LaTeX, OutputFormat.Unicode): + raise ValueError(f"unsupported output format: {output!r}") def angle_to_str( diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 54732823e..1ec02cc31 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -293,3 +293,8 @@ def test_angle_to_str_radian_fallback(output: OutputFormat) -> None: def test_angle_to_str_radian_fallback_max_denominator() -> None: assert angle_to_str(0.7, OutputFormat.Unicode, max_denominator=3) == f"{0.7 * math.pi:.2f}" + + +def test_validate_output_format() -> None: + with pytest.raises(ValueError, match="unsupported output format"): + complex_to_str(1, object()) # type: ignore[arg-type] From 7df4907848d20e3173a9872480fb7d42e341cc22 Mon Sep 17 00:00:00 2001 From: Aidan Sims Date: Sun, 7 Jun 2026 11:32:33 -0700 Subject: [PATCH 6/6] add parens --- docs/source/visualization.rst | 6 ++++++ graphix/pretty_print.py | 39 +++++++++++++++++++++++++++++++++-- graphix/sim/density_matrix.py | 6 ++++++ graphix/sim/statevec.py | 5 +++++ tests/test_statevec.py | 14 +++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) diff --git a/docs/source/visualization.rst b/docs/source/visualization.rst index a48628c6c..f30097b01 100644 --- a/docs/source/visualization.rst +++ b/docs/source/visualization.rst @@ -17,6 +17,12 @@ If flow or gflow exist, the tool take them into account and show the information This modules provides functions to format patterns and flows. +``complex_to_str``, ``statevec_to_str``, and ``density_matrix_to_str`` format +concrete numeric amplitudes and matrix elements. They do not support symbolic +parameters such as :class:`~graphix.parameter.Placeholder`; substitute +parameters before calling these functions, or use ``str(...)`` on the +statevector or density matrix object for a raw representation. + .. currentmodule:: graphix.pretty_print .. autoclass:: OutputFormat diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index 2a1c41479..71fcc9e76 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -343,6 +343,11 @@ def complex_to_str( ------- str The formatted complex number. + + Notes + ----- + This function expects a concrete numeric value. Symbolic expressions such as + :class:`~graphix.parameter.Placeholder` are not supported. """ _validate_output_format(output) z = complex(z) @@ -387,6 +392,18 @@ def _ket_to_str(bits: str, output: OutputFormat) -> str: return f"|{bits}>" +def _coeff_needs_parens(coeff: str) -> bool: + return " + " in coeff or " - " in coeff + + +def _coeff_ket_body(coeff: str, ket_str: str, output: OutputFormat) -> str: + if _coeff_needs_parens(coeff): + if output == OutputFormat.LaTeX: + return rf"\left({coeff}\right){ket_str}" + return f"({coeff}){ket_str}" + return f"{coeff}{ket_str}" + + def _format_statevec_term( amp: complex, ket: str, @@ -402,9 +419,11 @@ def _format_statevec_term( return "+", ket_str if coeff == "-1": return "-", ket_str + if _coeff_needs_parens(coeff): + return "+", _coeff_ket_body(coeff, ket_str, output) if coeff.startswith("-"): - return "-", f"{coeff[1:]}{ket_str}" - return "+", f"{coeff}{ket_str}" + return "-", _coeff_ket_body(coeff[1:], ket_str, output) + return "+", _coeff_ket_body(coeff, ket_str, output) def _join_statevec_terms(terms: list[tuple[str, str]]) -> str: @@ -457,6 +476,14 @@ def statevec_to_str( ------- str The formatted statevector as a sum of ket terms. + + Notes + ----- + This function formats concrete numeric amplitudes only. Symbolic or + parametric values (for example :class:`~graphix.parameter.Placeholder`) are + not supported. Substitute parameters with :meth:`~graphix.sim.statevec.Statevec.subs` + or :meth:`~graphix.sim.statevec.Statevec.xreplace` before calling this + function, or use ``str(statevec)`` for a raw representation. """ _validate_output_format(output) amplitudes = statevec.to_dict(encoding=encoding, rtol=rtol, atol=atol) @@ -513,6 +540,14 @@ def density_matrix_to_str( ------- str The formatted density matrix. + + Notes + ----- + This function formats concrete numeric entries only. Symbolic or parametric + values (for example :class:`~graphix.parameter.Placeholder`) are not + supported. Substitute parameters with :meth:`~graphix.sim.density_matrix.DensityMatrix.subs` + or :meth:`~graphix.sim.density_matrix.DensityMatrix.xreplace` before calling + this function, or use ``str(density_matrix)`` for a raw representation. """ _validate_output_format(output) rho = density_matrix.rho diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index 1471f20d6..cef35af77 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -378,6 +378,12 @@ def draw( str The formatted density matrix. + Notes + ----- + Requires concrete numeric entries. For parametric density matrices, + substitute with :meth:`subs` or :meth:`xreplace` first, or use + ``str(self)``. + See Also -------- :func:`graphix.pretty_print.density_matrix_to_str` diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index c67860de2..3db293569 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -513,6 +513,11 @@ def draw( str The formatted statevector as a sum of ket terms. + Notes + ----- + Requires concrete numeric amplitudes. For parametric states, substitute + with :meth:`subs` or :meth:`xreplace` first, or use ``str(self)``. + See Also -------- :func:`graphix.pretty_print.statevec_to_str` diff --git a/tests/test_statevec.py b/tests/test_statevec.py index 45468e26f..748d0e9f6 100644 --- a/tests/test_statevec.py +++ b/tests/test_statevec.py @@ -277,6 +277,20 @@ def test_statevec_draw_superposition() -> None: ) +def test_statevec_draw_mixed_amplitude() -> None: + amp = 0.25 + 0.25j + other = np.sqrt(1 - abs(amp) ** 2) + sv = Statevec(data=np.array([amp, other], dtype=np.complex128)) + assert statevec_to_str(sv, OutputFormat.Unicode) == "(1/4 + i/4)|0⟩ + 0.935414|1⟩" + assert statevec_to_str(sv, OutputFormat.LaTeX) == ( + r"\(\left(\frac{1}{4} + \frac{\mathrm{i}}{4}\right)\ket{0} + 0.935414\ket{1}\)" + ) + + neg_amp = -0.25 - 0.25j + sv_neg = Statevec(data=np.array([neg_amp, other], dtype=np.complex128)) + assert statevec_to_str(sv_neg, OutputFormat.Unicode) == "(-1/4 - i/4)|0⟩ + 0.935414|1⟩" + + def test_normalize() -> None: statevec = Statevec(nqubit=1, data=BasicStates.PLUS) statevec.remove_qubit(0)