diff --git a/docs/source/visualization.rst b/docs/source/visualization.rst index 489838cbd..f30097b01 100644 --- a/docs/source/visualization.rst +++ b/docs/source/visualization.rst @@ -17,12 +17,24 @@ If flow or gflow exist, the tool take them into account and show the information This modules provides functions to format patterns and flows. +``complex_to_str``, ``statevec_to_str``, and ``density_matrix_to_str`` format +concrete numeric amplitudes and matrix elements. They do not support symbolic +parameters such as :class:`~graphix.parameter.Placeholder`; substitute +parameters before calling these functions, or use ``str(...)`` on the +statevector or density matrix object for a raw representation. + .. currentmodule:: graphix.pretty_print .. autoclass:: OutputFormat .. autofunction:: angle_to_str +.. autofunction:: complex_to_str + +.. autofunction:: statevec_to_str + +.. autofunction:: density_matrix_to_str + .. autofunction:: command_to_str .. autofunction:: pattern_to_str diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..71fcc9e76 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -2,30 +2,42 @@ from __future__ import annotations +import cmath import enum import math import string from enum import Enum from fractions import Fraction from math import pi -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, Literal, SupportsComplex, SupportsFloat # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never from graphix import command -from graphix.fundamentals import AbstractMeasurement, Axis, Plane, Sign, angle_to_rad, rad_to_angle +from graphix.fundamentals import ( + AbstractMeasurement, + Axis, + Plane, + Sign, + angle_to_rad, + rad_to_angle, +) from graphix.measurements import BlochMeasurement, PauliMeasurement from graphix.parameter import AffineExpression if TYPE_CHECKING: - from collections.abc import Container, Iterable, Mapping, Sequence + from collections.abc import Callable, Container, Iterable, Mapping, Sequence from collections.abc import Set as AbstractSet from graphix.command import Node from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import Statevec + +_ENCODING = Literal["LSB", "MSB"] class OutputFormat(Enum): @@ -36,8 +48,17 @@ class OutputFormat(Enum): Unicode = enum.auto() +def _validate_output_format(output: OutputFormat) -> None: + if output not in (OutputFormat.ASCII, OutputFormat.LaTeX, OutputFormat.Unicode): + raise ValueError(f"unsupported output format: {output!r}") + + def angle_to_str( - angle: Angle, output: OutputFormat, max_denominator: int = 1000, multiplication_sign: bool = False + angle: Angle, + output: OutputFormat, + max_denominator: int = 1000, + multiplication_sign: bool = False, + abs_tol: float = 0.0, ) -> str: r""" Return a string representation of an angle given in units of π. @@ -61,18 +82,19 @@ def angle_to_str( ``2×π`` in Unicode, ``2 \times \pi`` in LaTeX, and ``2*pi`` in ASCII. If ``False``, the multiplication sign is implicit: ``2π`` in Unicode, ``2\pi`` in LaTeX, ``2pi`` in ASCII. + abs_tol : float, optional + Absolute tolerance passed to :func:`math.isclose` (default: ``0.0``). Returns ------- str The formatted angle. """ + _validate_output_format(output) frac = Fraction(angle).limit_denominator(max_denominator) - if not math.isclose(angle, float(frac)): - rad = angle_to_rad(angle) - - return f"{rad}" + if not math.isclose(angle, float(frac), abs_tol=abs_tol): + return f"{angle_to_rad(angle):.2f}" num, den = frac.numerator, frac.denominator sign = "-" if num < 0 else "" @@ -110,6 +132,460 @@ def mkfrac(num: str, den: str) -> str: return f"{sign}{mkfrac(num_str, den_str)}" +_MAX_RADICAND = 10 + + +def _format_helpers( + output: OutputFormat, +) -> tuple[Callable[[str, str], str], Callable[[int], str]]: + if output == OutputFormat.LaTeX: + + def mkfrac(num: str, den: str) -> str: + return rf"\frac{{{num}}}{{{den}}}" + + def sqrt(n: int) -> str: + return rf"\sqrt{{{n}}}" + + elif output == OutputFormat.Unicode: + + def mkfrac(num: str, den: str) -> str: + return f"{num}/{den}" + + def sqrt(n: int) -> str: + return f"√{n}" + + else: + + def mkfrac(num: str, den: str) -> str: + return f"{num}/{den}" + + def sqrt(n: int) -> str: + return f"sqrt({n})" + + return mkfrac, sqrt + + +def _real_scalar_to_str( + x: float, + output: OutputFormat, + *, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, + max_radicand: int = _MAX_RADICAND, +) -> str | None: + mkfrac, sqrt = _format_helpers(output) + + frac = Fraction(x).limit_denominator(max_denominator) + if math.isclose(x, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + num, den = frac.numerator, frac.denominator + sign = "-" if num < 0 else "" + num = abs(num) + if den == 1: + return f"{sign}{num}" + return f"{sign}{mkfrac(str(num), str(den))}" + + for n in range(1, max_radicand + 1): + root = math.sqrt(n) + for d in range(1, max_denominator + 1): + val = root / d + if math.isclose(x, val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + return num_str if d == 1 else mkfrac(num_str, str(d)) + if math.isclose(x, -val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + formatted = num_str if d == 1 else mkfrac(num_str, str(d)) + return f"-{formatted}" + + return None + + +def _scalar_or_decimal( + x: float, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + result = _real_scalar_to_str(x, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + if result: + return result + return f"{x:g}" + + +def _imag_scalar_to_str( + x: float, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + i = r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + x = abs(x) + + if math.isclose(x, 1.0, rel_tol=rel_tol, abs_tol=abs_tol): + return i + + _, sqrt = _format_helpers(output) + + frac = Fraction(x).limit_denominator(max_denominator) + if math.isclose(x, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + num, den = abs(frac.numerator), frac.denominator + if den == 1: + return f"{num}{i}" + if num == 1: + return f"{i}/{den}" if output != OutputFormat.LaTeX else rf"\frac{{{i}}}{{{den}}}" + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}{i}}}{{{den}}}" + return f"{num}{i}/{den}" + + for n in range(1, _MAX_RADICAND + 1): + root = math.sqrt(n) + for d in range(1, max_denominator + 1): + val = root / d + if math.isclose(x, val, rel_tol=rel_tol, abs_tol=abs_tol): + num_str = sqrt(n) if n != 1 else "1" + if d == 1: + return i if num_str == "1" else f"{num_str}{i}" + if num_str == "1": + return f"{i}/{d}" if output != OutputFormat.LaTeX else rf"\frac{{{i}}}{{{d}}}" + if output == OutputFormat.LaTeX: + return rf"\frac{{{num_str}{i}}}{{{d}}}" + return f"{num_str}{i}/{d}" + + return f"{x:g}{i}" + + +def _exp_i_to_str(angle_str: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\mathrm{{e}}^{{\mathrm{{i}}{angle_str}}}" + if output == OutputFormat.Unicode: + return f"e^(i{angle_str})" + return f"e^(i*{angle_str})" + + +def _cartesian_to_str( + z: complex, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> str: + real_zero = math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) + imag_zero = math.isclose(z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) + + if imag_zero: + return _scalar_or_decimal( + z.real, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + + imag_part = _imag_scalar_to_str( + z.imag, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + + if real_zero: + return f"-{imag_part}" if z.imag < 0 else imag_part + + real_str = _scalar_or_decimal( + z.real, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + if z.imag >= 0: + return f"{real_str} + {imag_part}" + return f"{real_str} - {imag_part}" + + +def complex_to_str( + z: complex | SupportsComplex, + output: OutputFormat, + *, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a complex number. + + Common values are rendered symbolically: + + - rational reals such as ``0.25`` as ``1/4``, + - radical rationals such as ``0.70710678`` as ``√2/2``, + - unit-modulus values such as ``0.5 + 0.8660254j`` as ``e^(iπ/3)``. + + Parameters + ---------- + z : complex + The complex number to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + max_denominator : int, optional + Maximum denominator for detecting simple fractions and radical forms (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`math.isclose` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`math.isclose` (default: ``1e-8``). + + Returns + ------- + str + The formatted complex number. + + Notes + ----- + This function expects a concrete numeric value. Symbolic expressions such as + :class:`~graphix.parameter.Placeholder` are not supported. + """ + _validate_output_format(output) + z = complex(z) + + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "0" + + if math.isclose(abs(z), 1.0, rel_tol=rel_tol, abs_tol=abs_tol): + if math.isclose(z.real, 1.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "1" + if math.isclose(z.real, -1.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 0.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return "-1" + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, 1.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + if math.isclose(z.real, 0.0, rel_tol=rel_tol, abs_tol=abs_tol) and math.isclose( + z.imag, -1.0, rel_tol=rel_tol, abs_tol=abs_tol + ): + return r"-\mathrm{i}" if output == OutputFormat.LaTeX else "-i" + + angle = cmath.phase(z) / pi + frac = Fraction(angle).limit_denominator(max_denominator) + if math.isclose(angle, float(frac), rel_tol=rel_tol, abs_tol=abs_tol): + angle_str = angle_to_str(angle, output, max_denominator=max_denominator, abs_tol=abs_tol) + return _exp_i_to_str(angle_str, output) + + return _cartesian_to_str(z, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + + +def _ket_to_str(bits: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\ket{{{bits}}}" + if output == OutputFormat.Unicode: + return f"|{bits}⟩" + return f"|{bits}>" + + +def _coeff_needs_parens(coeff: str) -> bool: + return " + " in coeff or " - " in coeff + + +def _coeff_ket_body(coeff: str, ket_str: str, output: OutputFormat) -> str: + if _coeff_needs_parens(coeff): + if output == OutputFormat.LaTeX: + return rf"\left({coeff}\right){ket_str}" + return f"({coeff}){ket_str}" + return f"{coeff}{ket_str}" + + +def _format_statevec_term( + amp: complex, + ket: str, + output: OutputFormat, + *, + max_denominator: int, + rel_tol: float, + abs_tol: float, +) -> tuple[str, str]: + coeff = complex_to_str(amp, output, max_denominator=max_denominator, rel_tol=rel_tol, abs_tol=abs_tol) + ket_str = _ket_to_str(ket, output) + if coeff == "1": + return "+", ket_str + if coeff == "-1": + return "-", ket_str + if _coeff_needs_parens(coeff): + return "+", _coeff_ket_body(coeff, ket_str, output) + if coeff.startswith("-"): + return "-", _coeff_ket_body(coeff[1:], ket_str, output) + return "+", _coeff_ket_body(coeff, ket_str, output) + + +def _join_statevec_terms(terms: list[tuple[str, str]]) -> str: + if not terms: + return "0" + sign, body = terms[0] + result = f"-{body}" if sign == "-" else body + for sign, body in terms[1:]: + result += f" - {body}" if sign == "-" else f" + {body}" + return result + + +def statevec_to_str( + statevec: Statevec, + output: OutputFormat, + encoding: _ENCODING = "MSB", + *, + rtol: float = 0.0, + atol: float = 1e-8, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a statevector in ket notation. + + Uses :meth:`graphix.sim.statevec.Statevec.to_dict` to obtain non-zero amplitudes, + and formats each amplitude with :func:`complex_to_str`. + + Parameters + ---------- + statevec : Statevec + The statevector to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets. See :meth:`graphix.sim.statevec.Statevec.to_dict`. + rtol : float, default=0.0 + Relative tolerance for filtering zero amplitudes, passed to :meth:`Statevec.to_dict`. + atol : float, default=1e-8 + Absolute tolerance for filtering zero amplitudes, passed to :meth:`Statevec.to_dict`. + max_denominator : int, optional + Maximum denominator for detecting simple amplitudes (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`complex_to_str` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`complex_to_str` (default: ``1e-8``). + + Returns + ------- + str + The formatted statevector as a sum of ket terms. + + Notes + ----- + This function formats concrete numeric amplitudes only. Symbolic or + parametric values (for example :class:`~graphix.parameter.Placeholder`) are + not supported. Substitute parameters with :meth:`~graphix.sim.statevec.Statevec.subs` + or :meth:`~graphix.sim.statevec.Statevec.xreplace` before calling this + function, or use ``str(statevec)`` for a raw representation. + """ + _validate_output_format(output) + amplitudes = statevec.to_dict(encoding=encoding, rtol=rtol, atol=atol) + terms = [ + _format_statevec_term( + complex(amp), + ket, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + for ket, amp in amplitudes.items() + ] + result = _join_statevec_terms(terms) + if output == OutputFormat.LaTeX: + return f"\\({result}\\)" + return result + + +def density_matrix_to_str( + density_matrix: DensityMatrix, + output: OutputFormat, + *, + rtol: float = 0.0, + atol: float = 1e-8, + max_denominator: int = 1000, + rel_tol: float = 1e-9, + abs_tol: float = 1e-8, +) -> str: + r""" + Return a string representation of a density matrix. + + Formats each matrix element with :func:`complex_to_str`. + + Parameters + ---------- + density_matrix : DensityMatrix + The density matrix to format. + output : OutputFormat + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for displaying negligible elements as ``0``. + atol : float, default=1e-8 + Absolute tolerance for displaying negligible elements as ``0``. + max_denominator : int, optional + Maximum denominator for detecting simple matrix elements (default: 1000). + rel_tol : float, optional + Relative tolerance passed to :func:`complex_to_str` (default: ``1e-9``). + abs_tol : float, optional + Absolute tolerance passed to :func:`complex_to_str` (default: ``1e-8``). + + Returns + ------- + str + The formatted density matrix. + + Notes + ----- + This function formats concrete numeric entries only. Symbolic or parametric + values (for example :class:`~graphix.parameter.Placeholder`) are not + supported. Substitute parameters with :meth:`~graphix.sim.density_matrix.DensityMatrix.subs` + or :meth:`~graphix.sim.density_matrix.DensityMatrix.xreplace` before calling + this function, or use ``str(density_matrix)`` for a raw representation. + """ + _validate_output_format(output) + rho = density_matrix.rho + nrows, ncols = rho.shape + + def format_cell(value: complex) -> str: + if math.isclose(abs(value), 0.0, rel_tol=rtol, abs_tol=atol): + return "0" + return complex_to_str( + value, + output, + max_denominator=max_denominator, + rel_tol=rel_tol, + abs_tol=abs_tol, + ) + + cells = [[format_cell(complex(rho[i, j])) for j in range(ncols)] for i in range(nrows)] + col_widths = [max(len(cells[i][j]) for i in range(nrows)) for j in range(ncols)] + + if output == OutputFormat.LaTeX: + rows = [" & ".join(cell.rjust(col_widths[j]) for j, cell in enumerate(row)) for row in cells] + body = r" \\ ".join(rows) + return rf"\(\begin{{pmatrix}} {body} \end{{pmatrix}}\)" + + lines: list[str] = [] + for i, row in enumerate(cells): + inner = ", ".join(cell.rjust(col_widths[j]) for j, cell in enumerate(row)) + if nrows == 1: + lines.append(f"[[{inner}]]") + elif i == 0: + lines.append(f"[[{inner}],") + elif i == nrows - 1: + lines.append(f" [{inner}]]") + else: + lines.append(f" [{inner}],") + return "\n".join(lines) + + def domain_to_str(domain: set[Node]) -> str: """Return the string representation of a domain.""" return f"{{{','.join(str(node) for node in domain)}}}" @@ -299,7 +775,10 @@ def set_to_str(objects: Iterable[object], output: OutputFormat) -> str: def correction_function_to_str( - correction_function: Mapping[int, AbstractSet[int]], cf_name: str, output: OutputFormat, multiline: bool = False + correction_function: Mapping[int, AbstractSet[int]], + cf_name: str, + output: OutputFormat, + multiline: bool = False, ) -> str: """Convert a correction function mapping to a formatted string representation. @@ -411,7 +890,11 @@ def flow_to_str(flow: PauliFlow[AbstractMeasurement], output: OutputFormat, mult ) -def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputFormat, multiline: bool = False) -> str: +def xzcorr_to_str( + xzcorr: XZCorrections[AbstractMeasurement], + output: OutputFormat, + multiline: bool = False, +) -> str: """Convert an XZCorrections object to a formatted string representation. Parameters diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..cef35af77 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -354,6 +355,41 @@ def flatten(self) -> Matrix: """Return flattened density matrix.""" return self.rho.flatten() + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + rtol: float = 0.0, + atol: float = 1e-8, + ) -> str: + r"""Return a pretty-printed string representation of the density matrix. + + Parameters + ---------- + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for displaying negligible elements as ``0``. + atol : float, default=1e-8 + Absolute tolerance for displaying negligible elements as ``0``. + + Returns + ------- + str + The formatted density matrix. + + Notes + ----- + Requires concrete numeric entries. For parametric density matrices, + substitute with :meth:`subs` or :meth:`xreplace` first, or use + ``str(self)``. + + See Also + -------- + :func:`graphix.pretty_print.density_matrix_to_str` + """ + return density_matrix_to_str(self, output, rtol=rtol, atol=atol) + def apply_channel(self, channel: KrausChannel, qargs: Sequence[int]) -> None: """Apply a channel to a density matrix. diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..3db293569 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -486,6 +487,43 @@ def to_dict( """ return self._to_dict_map(lambda x: x, encoding, rtol=rtol, atol=atol) + def draw( + self, + encoding: _ENCODING = "MSB", + output: OutputFormat = OutputFormat.Unicode, + *, + rtol: float = 0.0, + atol: float = 1e-8, + ) -> str: + r"""Return a pretty-printed string representation of the statevector in ket notation. + + Parameters + ---------- + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets. See :meth:`to_dict` for additional information. + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style: Unicode, LaTeX, or ASCII. + rtol : float, default=0.0 + Relative tolerance for filtering zero amplitudes. See :meth:`to_dict`. + atol : float, default=1e-8 + Absolute tolerance for filtering zero amplitudes. See :meth:`to_dict`. + + Returns + ------- + str + The formatted statevector as a sum of ket terms. + + Notes + ----- + Requires concrete numeric amplitudes. For parametric states, substitute + with :meth:`subs` or :meth:`xreplace` first, or use ``str(self)``. + + See Also + -------- + :func:`graphix.pretty_print.statevec_to_str` + """ + return statevec_to_str(self, output, encoding, rtol=rtol, atol=atol) + def to_prob_dict( self, encoding: _ENCODING = "MSB", *, rtol: float = 0.0, atol: float = 1e-8 ) -> dict[str, np.object_ | np.float64]: diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index 7c67909c5..ccf73179e 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -15,6 +15,7 @@ from graphix.channels import KrausChannel, dephasing_channel, depolarising_channel from graphix.fundamentals import ANGLE_PI, Plane from graphix.ops import Ops +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.density_matrix import DensityMatrix, DensityMatrixBackend from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.simulator import DefaultMeasureMethod @@ -929,3 +930,56 @@ def test_measure(self, outcome: Outcome) -> None: else np.kron(np.array([[1, 0], [0, 0]]), np.ones((2, 2)) / 2) ) assert np.allclose(backend.state.rho, expected_matrix) + + +@pytest.mark.parametrize( + ("output", "expected"), + [ + (OutputFormat.Unicode, "[[0, 0, 0, 0],\n [0, 1, 0, 0],\n [0, 0, 0, 0],\n [0, 0, 0, 0]]"), + (OutputFormat.ASCII, "[[0, 0, 0, 0],\n [0, 1, 0, 0],\n [0, 0, 0, 0],\n [0, 0, 0, 0]]"), + ( + OutputFormat.LaTeX, + r"\(\begin{pmatrix} 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{pmatrix}\)", + ), + ], +) +def test_density_matrix_draw_pure_state(output: OutputFormat, expected: str) -> None: + dm = DensityMatrix(data=[BasicStates.ZERO, BasicStates.ONE]) + assert dm.rho.shape == (4, 4) + assert dm.draw(output=output) == expected + assert density_matrix_to_str(dm, output) == expected + + +def test_density_matrix_draw_superposition() -> None: + dm = DensityMatrix(data=BasicStates.PLUS) + assert dm.draw(output=OutputFormat.Unicode) == "[[1/2, 1/2],\n [1/2, 1/2]]" + assert density_matrix_to_str(dm, OutputFormat.Unicode) == "[[1/2, 1/2],\n [1/2, 1/2]]" + assert dm.draw(output=OutputFormat.LaTeX) == ( + r"\(\begin{pmatrix} \frac{1}{2} & \frac{1}{2} \\ \frac{1}{2} & \frac{1}{2} \end{pmatrix}\)" + ) + + +def test_density_matrix_draw_single_row() -> None: + dm = DensityMatrix(nqubit=0, data=np.array([[1.0]], dtype=np.complex128)) + assert density_matrix_to_str(dm, OutputFormat.Unicode) == "[[1]]" + assert density_matrix_to_str(dm, OutputFormat.ASCII) == "[[1]]" + assert density_matrix_to_str(dm, OutputFormat.LaTeX) == r"\(\begin{pmatrix} 1 \end{pmatrix}\)" + + +def test_density_matrix_draw_4x4() -> None: + sv = Statevec(data=np.array([1, 0, 0, 1], dtype=np.complex128) / np.sqrt(2)) + dm = DensityMatrix(data=sv) + assert dm.rho.shape == (4, 4) + assert ( + dm.draw(output=OutputFormat.Unicode) + == "[[1/2, 0, 0, 1/2],\n [ 0, 0, 0, 0],\n [ 0, 0, 0, 0],\n [1/2, 0, 0, 1/2]]" + ) + assert density_matrix_to_str(dm, OutputFormat.Unicode) == ( + "[[1/2, 0, 0, 1/2],\n [ 0, 0, 0, 0],\n [ 0, 0, 0, 0],\n [1/2, 0, 0, 1/2]]" + ) + assert dm.draw(output=OutputFormat.LaTeX) == ( + r"\(\begin{pmatrix} \frac{1}{2} & 0 & 0 & \frac{1}{2} \\ " + r" 0 & 0 & 0 & 0 \\ " + r" 0 & 0 & 0 & 0 \\ " + r"\frac{1}{2} & 0 & 0 & \frac{1}{2} \end{pmatrix}\)" + ) diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..1ec02cc31 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,5 +1,7 @@ from __future__ import annotations +import cmath +import math from typing import TYPE_CHECKING import networkx as nx @@ -13,7 +15,7 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, angle_to_str, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit from graphix.transpiler import Circuit @@ -202,3 +204,97 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +@pytest.mark.parametrize( + ("z", "output", "expected"), + [ + (0.25, OutputFormat.ASCII, "1/4"), + (0.25, OutputFormat.Unicode, "1/4"), + (0.25, OutputFormat.LaTeX, r"\frac{1}{4}"), + (0.25 + 0j, OutputFormat.Unicode, "1/4"), + (0.70710678, OutputFormat.ASCII, "sqrt(2)/2"), + (0.70710678, OutputFormat.Unicode, "√2/2"), + (0.70710678, OutputFormat.LaTeX, r"\frac{\sqrt{2}}{2}"), + (-0.70710678, OutputFormat.ASCII, "-sqrt(2)/2"), + (-0.70710678, OutputFormat.Unicode, "-√2/2"), + (-0.70710678, OutputFormat.LaTeX, r"-\frac{\sqrt{2}}{2}"), + (math.sqrt(3) / 2, OutputFormat.Unicode, "√3/2"), + (math.sqrt(3) / 2, OutputFormat.LaTeX, r"\frac{\sqrt{3}}{2}"), + (-math.sqrt(3) / 2, OutputFormat.Unicode, "-√3/2"), + (-math.sqrt(2), OutputFormat.Unicode, "-√2"), + (-math.sqrt(2), OutputFormat.LaTeX, r"-\sqrt{2}"), + (0.5 + 0.8660254j, OutputFormat.ASCII, "e^(i*pi/3)"), + (0.5 + 0.8660254j, OutputFormat.Unicode, "e^(iπ/3)"), + (0.5 + 0.8660254j, OutputFormat.LaTeX, r"\mathrm{e}^{\mathrm{i}\frac{\pi}{3}}"), + (cmath.exp(1j * math.pi / 3), OutputFormat.Unicode, "e^(iπ/3)"), + (0, OutputFormat.Unicode, "0"), + (1, OutputFormat.Unicode, "1"), + (-1, OutputFormat.Unicode, "-1"), + (2, OutputFormat.Unicode, "2"), + (-3, OutputFormat.ASCII, "-3"), + (2, OutputFormat.LaTeX, "2"), + (2j, OutputFormat.Unicode, "2i"), + (-2j, OutputFormat.LaTeX, r"-2\mathrm{i}"), + (0.25j, OutputFormat.Unicode, "i/4"), + (-0.25j, OutputFormat.LaTeX, r"-\frac{\mathrm{i}}{4}"), + (1j, OutputFormat.Unicode, "i"), + (1j, OutputFormat.LaTeX, r"\mathrm{i}"), + (-1j, OutputFormat.Unicode, "-i"), + (-1j, OutputFormat.LaTeX, r"-\mathrm{i}"), + (0.25 + 0.25j, OutputFormat.ASCII, "1/4 + i/4"), + (0.25 + 0.25j, OutputFormat.Unicode, "1/4 + i/4"), + (0.25 + 0.25j, OutputFormat.LaTeX, r"\frac{1}{4} + \frac{\mathrm{i}}{4}"), + (0.25 - 0.25j, OutputFormat.Unicode, "1/4 - i/4"), + (2 - 3j, OutputFormat.ASCII, "2 - 3i"), + (2 - 3j, OutputFormat.LaTeX, r"2 - 3\mathrm{i}"), + (2 + 1j, OutputFormat.Unicode, "2 + i"), + (0.75j, OutputFormat.LaTeX, r"\frac{3\mathrm{i}}{4}"), + (math.sqrt(2) / 2 * 1j, OutputFormat.Unicode, "√2i/2"), + (math.sqrt(2) / 2 * 1j, OutputFormat.LaTeX, r"\frac{\sqrt{2}\mathrm{i}}{2}"), + (2 + math.sqrt(2) / 2 * 1j, OutputFormat.Unicode, "2 + √2i/2"), + ], +) +def test_complex_to_str(z: complex, output: OutputFormat, expected: str) -> None: + assert complex_to_str(z, output) == expected + + +def test_complex_to_str_fallback() -> None: + z = 0.123 + 0.456j + assert complex_to_str(z, OutputFormat.ASCII, max_denominator=1) == "0.123 + 0.456i" + assert complex_to_str(z, OutputFormat.Unicode, max_denominator=1) == "0.123 + 0.456i" + assert complex_to_str(z, OutputFormat.LaTeX, max_denominator=1) == "0.123 + 0.456\\mathrm{i}" + + +@pytest.mark.parametrize( + ("angle", "output", "expected"), + [ + (0, OutputFormat.Unicode, "0"), + (0, OutputFormat.ASCII, "0"), + (0, OutputFormat.LaTeX, "0"), + (2, OutputFormat.Unicode, "2π"), + (2, OutputFormat.ASCII, "2pi"), + (2, OutputFormat.LaTeX, r"2\pi"), + (-3, OutputFormat.Unicode, "-3π"), + (0.5, OutputFormat.Unicode, "π/2"), + (0.5, OutputFormat.ASCII, "pi/2"), + (0.5, OutputFormat.LaTeX, r"\frac{\pi}{2}"), + ], +) +def test_angle_to_str_fraction(angle: float, output: OutputFormat, expected: str) -> None: + assert angle_to_str(angle, output) == expected + + +@pytest.mark.parametrize("output", list(OutputFormat)) +def test_angle_to_str_radian_fallback(output: OutputFormat) -> None: + angle = 0.123456789 + assert angle_to_str(angle, output) == f"{angle * math.pi:.2f}" + + +def test_angle_to_str_radian_fallback_max_denominator() -> None: + assert angle_to_str(0.7, OutputFormat.Unicode, max_denominator=3) == f"{0.7 * math.pi:.2f}" + + +def test_validate_output_format() -> None: + with pytest.raises(ValueError, match="unsupported output format"): + complex_to_str(1, object()) # type: ignore[arg-type] diff --git a/tests/test_statevec.py b/tests/test_statevec.py index e7ce8d925..748d0e9f6 100644 --- a/tests/test_statevec.py +++ b/tests/test_statevec.py @@ -8,6 +8,7 @@ from graphix.fundamentals import ANGLE_PI, Plane from graphix.pattern import Pattern +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.statevec import Statevec, _norm_numeric from graphix.states import BasicStates, PlanarState @@ -239,6 +240,57 @@ def test_to_prob_dict(self, encoding: _ENCODING, dict_ref: Mapping[str, float]) assert np.isclose(0, amp2.imag) +@pytest.mark.parametrize( + ("encoding", "output", "expected"), + [ + ("MSB", OutputFormat.Unicode, "|01⟩"), + ("MSB", OutputFormat.ASCII, "|01>"), + ("MSB", OutputFormat.LaTeX, r"\(\ket{01}\)"), + ("LSB", OutputFormat.Unicode, "|10⟩"), + ], +) +def test_statevec_draw_single_ket(encoding: _ENCODING, output: OutputFormat, expected: str) -> None: + sv = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + assert sv.draw(encoding=encoding, output=output) == expected + assert statevec_to_str(sv, output, encoding=encoding) == expected + + +def test_statevec_draw_single_ket_negative_amplitude() -> None: + sv = Statevec(data=np.array([0, -1], dtype=np.complex128)) + assert statevec_to_str(sv, OutputFormat.Unicode) == "-|1⟩" + assert statevec_to_str(sv, OutputFormat.ASCII) == "-|1>" + assert statevec_to_str(sv, OutputFormat.LaTeX) == r"\(-\ket{1}\)" + + +def test_statevec_draw_empty() -> None: + sv = Statevec(data=BasicStates.PLUS) + assert statevec_to_str(sv, OutputFormat.Unicode, atol=1.0) == "0" + assert statevec_to_str(sv, OutputFormat.LaTeX, atol=1.0) == r"\(0\)" + + +def test_statevec_draw_superposition() -> None: + sv = Statevec(data=[BasicStates.ZERO, BasicStates.PLUS, BasicStates.MINUS]) + assert sv.draw(encoding="MSB", output=OutputFormat.Unicode) == "1/2|000⟩ - 1/2|001⟩ + 1/2|010⟩ - 1/2|011⟩" + assert statevec_to_str(sv, OutputFormat.Unicode, encoding="MSB") == ("1/2|000⟩ - 1/2|001⟩ + 1/2|010⟩ - 1/2|011⟩") + assert sv.draw(encoding="MSB", output=OutputFormat.LaTeX) == ( + r"\(\frac{1}{2}\ket{000} - \frac{1}{2}\ket{001} + \frac{1}{2}\ket{010} - \frac{1}{2}\ket{011}\)" + ) + + +def test_statevec_draw_mixed_amplitude() -> None: + amp = 0.25 + 0.25j + other = np.sqrt(1 - abs(amp) ** 2) + sv = Statevec(data=np.array([amp, other], dtype=np.complex128)) + assert statevec_to_str(sv, OutputFormat.Unicode) == "(1/4 + i/4)|0⟩ + 0.935414|1⟩" + assert statevec_to_str(sv, OutputFormat.LaTeX) == ( + r"\(\left(\frac{1}{4} + \frac{\mathrm{i}}{4}\right)\ket{0} + 0.935414\ket{1}\)" + ) + + neg_amp = -0.25 - 0.25j + sv_neg = Statevec(data=np.array([neg_amp, other], dtype=np.complex128)) + assert statevec_to_str(sv_neg, OutputFormat.Unicode) == "(-1/4 - i/4)|0⟩ + 0.935414|1⟩" + + def test_normalize() -> None: statevec = Statevec(nqubit=1, data=BasicStates.PLUS) statevec.remove_qubit(0)