diff --git a/graphix/__init__.py b/graphix/__init__.py index 12f8da89c..06eb27775 100644 --- a/graphix/__init__.py +++ b/graphix/__init__.py @@ -19,7 +19,7 @@ from graphix.parameter import Placeholder from graphix.pattern import DrawPatternAnnotations, Pattern from graphix.pauli import Pauli -from graphix.pretty_print import OutputFormat +from graphix.pretty_print import OutputFormat, complex_to_str, densitymatrix_to_str, statevec_to_str from graphix.sim import DensityMatrix, DensityMatrixBackend, Statevec, StatevectorBackend from graphix.space_minimization import SpaceMinimizationHeuristics from graphix.states import BasicStates, PlanarState @@ -68,5 +68,8 @@ "XZCorrections", "__version__", "angle_to_rad", + "complex_to_str", + "densitymatrix_to_str", "rad_to_angle", + "statevec_to_str", ] diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..e778df625 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -10,6 +10,8 @@ from math import pi from typing import TYPE_CHECKING, SupportsFloat +import numpy as np + # `assert_never` introduced in Python 3.11 from typing_extensions import assert_never @@ -22,6 +24,8 @@ from collections.abc import Container, Iterable, Mapping, Sequence from collections.abc import Set as AbstractSet + import numpy.typing as npt + from graphix.command import Node from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle @@ -439,3 +443,456 @@ def xzcorr_to_str(xzcorr: XZCorrections[AbstractMeasurement], output: OutputForm partial_order_to_str(xzcorr.partial_order_layers, output), ) ) + + +# --------------------------------------------------------------------------- +# Pretty-printing for quantum states (statevectors and density matrices) +# --------------------------------------------------------------------------- + + +_EPS = 1e-10 + + +def _format_imag_unit(output: OutputFormat) -> str: + """Return the imaginary unit formatted for the given output format.""" + if output == OutputFormat.LaTeX: + return r"\mathrm{i}" + return "i" + + +def _format_ket(basis: str, output: OutputFormat) -> str: + """Format a basis bitstring as a ket. + + Parameters + ---------- + basis : str + Bitstring representation of a basis state (e.g. ``"01"``). + output : OutputFormat + Desired formatting style. + + Returns + ------- + str + The formatted ket. + """ + if output == OutputFormat.LaTeX: + return rf"\lvert {basis}\rangle" + if output == OutputFormat.Unicode: + return f"|{basis}⟩" + return f"|{basis}>" + + +def _format_bra(basis: str, output: OutputFormat) -> str: + """Format a basis bitstring as a bra. + + Parameters + ---------- + basis : str + Bitstring representation of a basis state (e.g. ``"01"``). + output : OutputFormat + Desired formatting style. + + Returns + ------- + str + The formatted bra. + """ + if output == OutputFormat.LaTeX: + return rf"\langle {basis}\rvert" + if output == OutputFormat.Unicode: + return f"⟨{basis}|" + return f"<{basis}|" + + +def _format_frac(num: int, den: int, output: OutputFormat) -> str: + """Format a fraction *num* / *den*. + + Parameters + ---------- + num : int + Numerator. + den : int + Denominator. + output : OutputFormat + Desired formatting style. + + Returns + ------- + str + """ + if den == 1: + return str(num) + + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}}}{{{den}}}" + return f"{num}/{den}" + + +def _format_scalar(val: float, output: OutputFormat, max_denominator: int = 1000) -> str: + """Format a real scalar, detecting integers and simple fractions. + + Parameters + ---------- + val : float + Value to format. + output : OutputFormat + Desired formatting style. + max_denominator : int, optional + Maximum denominator for fraction detection (default: 1000). + + Returns + ------- + str + """ + if abs(val) < _EPS: + return "0" + + # Integer + if math.isclose(val, round(val), abs_tol=_EPS): + return str(round(val)) + + # Simple fraction + frac = Fraction(val).limit_denominator(max_denominator) + if math.isclose(val, float(frac), abs_tol=_EPS): + return _format_frac(frac.numerator, frac.denominator, output) + + # Square-root of a simple fraction: |val| ≈ sqrt(n/d) + val_sq = val * val + frac_sq = Fraction(val_sq).limit_denominator(max_denominator) + if math.isclose(val_sq, float(frac_sq), abs_tol=_EPS) and frac_sq != frac: + result = _format_sqrt_frac(frac_sq, output) + return f"-{result}" if val < 0 else result + + # Fallback: decimal + return f"{val:.6g}" + + +def _format_sqrt_frac(frac: Fraction, output: OutputFormat) -> str: + """Format the square root of a positive fraction. + + Parameters + ---------- + frac : Fraction + A positive fraction. The formatted string represents ``sqrt(frac)``. + output : OutputFormat + Desired formatting style. + + Returns + ------- + str + """ + num = frac.numerator + den = frac.denominator + + sqrt_sym: str + if output == OutputFormat.LaTeX: + sqrt_sym = r"\sqrt" + elif output == OutputFormat.Unicode: + sqrt_sym = "√" + else: + sqrt_sym = "sqrt" + + if den == 1: + if output == OutputFormat.LaTeX: + return rf"{sqrt_sym}{{{num}}}" + return f"{sqrt_sym}{num}" if output == OutputFormat.Unicode else f"{sqrt_sym}({num})" + + # Try to simplify by pulling perfect squares out of the denominator + den_sqrt = math.isqrt(den) + if den_sqrt * den_sqrt == den: + # den is a perfect square: √(num/den) = √num / den_sqrt + def _sqrt_num(n: int) -> str: + if output == OutputFormat.LaTeX: + return rf"{sqrt_sym}{{{n}}}" + return f"{sqrt_sym}{n}" if output == OutputFormat.Unicode else f"{sqrt_sym}({n})" + + if num == 1: + # √(1/den) = 1/den_sqrt + if output == OutputFormat.LaTeX: + return rf"\frac{{1}}{{{den_sqrt}}}" + return f"1/{den_sqrt}" + sqrt_top = _sqrt_num(num) + if output == OutputFormat.LaTeX: + return rf"\frac{{{sqrt_top}}}{{{den_sqrt}}}" + return f"{sqrt_top}/{den_sqrt}" + + if num == 1: + # 1/√(den) + if output == OutputFormat.LaTeX: + return rf"\frac{{1}}{{{sqrt_sym}{{{den}}}}}" + inner = f"{sqrt_sym}{den}" if output == OutputFormat.Unicode else f"{sqrt_sym}({den})" + return f"1/{inner}" + + # General: √(num/den) + if output == OutputFormat.LaTeX: + return rf"{sqrt_sym}{{\frac{{{num}}}{{{den}}}}}" + return f"{sqrt_sym}({num}/{den})" + + +def complex_to_str( + z: complex | np.complex128 | np.object_ | float, + output: OutputFormat, + max_denominator: int = 1000, +) -> str: + r"""Pretty-print a complex number. + + Detects common values and renders them in a human-readable form: + + - *zero* → ``"0"`` + - *simple fractions* → ``"1/2"``, ``"3/4"``, etc. + - *square roots of fractions* → ``"1/√2"``, ``"√3/2"``, etc. + - *pure exponentials* (when ``|z| ≈ 1``) → ``"e^{iπ/3}"``, ``"e^{-iπ/2}"``, etc. + - *fallback* → decimal notation. + + Parameters + ---------- + z : complex or np.complex128 or np.object_ or float + The complex number to format. + output : OutputFormat + Desired formatting style (ASCII, LaTeX or Unicode). + max_denominator : int, optional + Maximum denominator when detecting simple fractions (default: 1000). + + Returns + ------- + str + """ + if not isinstance(z, (complex, float, int, np.complexfloating, np.floating)): + return str(z) + + if abs(z) < _EPS: + return "0" + + imag_unit = _format_imag_unit(output) + re, im = z.real, z.imag + pure_real = abs(im) < _EPS + pure_imag = abs(re) < _EPS + + # Pure-real numbers: format as a scalar + if pure_real: + return _format_scalar(re, output, max_denominator) + + # Pure-imaginary numbers: format as scalar · i + if pure_imag: + imag_str = _format_scalar(abs(im), output, max_denominator) + if imag_str == "1": + return imag_unit if im > 0 else f"-{imag_unit}" + return f"{imag_str}{imag_unit}" if im > 0 else f"-{imag_str}{imag_unit}" + + # Exponential: |z| ≈ 1 and both components non-zero + if abs(abs(z) - 1.0) < _EPS: + angle = math.atan2(im, re) + angle_in_pi = angle / math.pi + frac = Fraction(angle_in_pi).limit_denominator(max_denominator) + if math.isclose(angle_in_pi, float(frac), abs_tol=_EPS): + return _format_exponential(float(frac), output, max_denominator) + + # General complex: a ± b·i + real_str = _format_scalar(re, output, max_denominator) + imag_str = _format_scalar(abs(im), output, max_denominator) + imag_str = imag_unit if imag_str == "1" else f"{imag_str}{imag_unit}" + + if im >= 0: + return f"{real_str}+{imag_str}" + return f"{real_str}-{imag_str}" + + +def _format_exponential(angle_in_pi: float, output: OutputFormat, max_denominator: int) -> str: + """Format a pure phase as ``e^{iθ}`` where *θ* is a fraction of π. + + Parameters + ---------- + angle_in_pi : float + The phase angle in units of π. + output : OutputFormat + Desired formatting style. + max_denominator : int + Maximum denominator for fraction detection. + + Returns + ------- + str + """ + imag_unit = _format_imag_unit(output) + + if abs(angle_in_pi) < _EPS: + return "1" + + frac = Fraction(angle_in_pi).limit_denominator(max_denominator) + num = frac.numerator + den = frac.denominator + + if num < 0: + sign_prefix = "-" + num = -num + else: + sign_prefix = "" + + # Build the body in i·num·π/den format + if output == OutputFormat.LaTeX: + if num == 1 and den == 1: + body = f"{imag_unit}\\pi" + elif num == 1: + body = f"{imag_unit}\\pi/{den}" + elif den == 1: + body = f"{imag_unit}{num}\\pi" + else: + body = f"{imag_unit}{num}\\pi/{den}" + elif output == OutputFormat.Unicode: + if num == 1 and den == 1: + body = f"{imag_unit}π" + elif num == 1: + body = f"{imag_unit}π/{den}" + elif den == 1: + body = f"{imag_unit}{num}π" + else: + body = f"{imag_unit}{num}π/{den}" + elif num == 1 and den == 1: + body = f"{imag_unit}*pi" + elif num == 1: + body = f"{imag_unit}*pi/{den}" + elif den == 1: + body = f"{imag_unit}*{num}*pi" + else: + body = f"{imag_unit}*{num}*pi/{den}" + + body = f"{sign_prefix}{body}" + + if output == OutputFormat.LaTeX: + return rf"\mathrm{{e}}^{{{body}}}" + if output == OutputFormat.Unicode: + return f"e^({body})" + return f"exp({body})" + + +def _basis_label(index: int, nqubit: int) -> str: + """Return the bitstring label for a basis index. + + Parameters + ---------- + index : int + Basis index (integer). + nqubit : int + Number of qubits. + + Returns + ------- + str + Bitstring of length *nqubit*. + """ + return f"{index:0{nqubit}b}" + + +def statevec_to_str( + sv_dict: Mapping[str, np.object_ | np.complex128], + output: OutputFormat, + max_denominator: int = 1000, +) -> str: + """Pretty-print a statevector from its dictionary representation. + + Parameters + ---------- + sv_dict : Mapping[str, complex] + Statevector dictionary as returned by :meth:`graphix.sim.statevec.Statevec.to_dict`. + output : OutputFormat + Desired formatting style (ASCII, LaTeX or Unicode). + max_denominator : int, optional + Maximum denominator for fraction detection (default: 1000). + + Returns + ------- + str + """ + if not sv_dict: + return "0" + + parts: list[str] = [] + for basis, amplitude in sv_dict.items(): + amp_str = complex_to_str(complex(amplitude), output, max_denominator) + ket = _format_ket(basis, output) + + if amp_str == "1": + term = ket + elif amp_str == "-1": + term = f"-{ket}" + else: + term = f"{amp_str}{ket}" + + if not parts: + parts.append(term) + elif term.startswith("-"): + parts.append(f" - {term[1:]}") + else: + parts.append(f" + {term}") + + return "".join(parts) + + +def densitymatrix_to_str( + rho: npt.NDArray[np.object_ | np.complex128], + nqubit: int, + output: OutputFormat, + *, + max_denominator: int = 1000, + cutoff: float = 1e-10, +) -> str: + r"""Pretty-print a density matrix using Dirac notation. + + Extracts non-zero elements and formats them as a sum of weighted projectors: + + .. math:: + + \\rho = \\sum_{i,j} \\rho_{ij} \\lvert i \\rangle\\langle j \\rvert + + Parameters + ---------- + rho : Matrix + The density matrix as a ``2**nqubit × 2**nqubit`` array. + nqubit : int + Number of qubits. + output : OutputFormat + Desired formatting style (ASCII, LaTeX or Unicode). + max_denominator : int, optional + Maximum denominator for fraction detection (default: 1000). + cutoff : float, optional + Tolerance below which matrix elements are treated as zero (default: ``1e-10``). + + Returns + ------- + str + """ + n = rho.shape[0] + + terms: list[tuple[complex, str, str]] = [] + for i in range(n): + for j in range(n): + val: complex = complex(rho[i, j]) + if abs(val) < cutoff: + continue + val_str = complex_to_str(val, output, max_denominator) + ket = _format_ket(_basis_label(i, nqubit), output) + bra = _format_bra(_basis_label(j, nqubit), output) + # |i⟩⟨j| + dirac = f"{ket}{bra}" + terms.append((val, val_str, dirac)) + + if not terms: + return "0" + + result_parts: list[str] = [] + for i, (_val, val_str, dirac) in enumerate(terms): + if val_str == "1": + term = dirac + elif val_str == "-1": + term = f"-{dirac}" + else: + term = f"{val_str}{dirac}" + + if i == 0: + result_parts.append(term) + elif term.startswith("-"): + result_parts.append(f" - {term[1:]}") + else: + result_parts.append(f" + {term}") + + return "".join(result_parts) diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..fc67b6c52 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, densitymatrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -354,6 +355,26 @@ def flatten(self) -> Matrix: """Return flattened density matrix.""" return self.rho.flatten() + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + *, + max_denominator: int = 1000, + cutoff: float = 1e-10, + ) -> None: + """Pretty-print the density matrix. + + Parameters + ---------- + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style. + max_denominator : int, default=1000 + Maximum denominator when detecting simple fractions. + cutoff : float, default=1e-10 + Tolerance below which matrix elements are treated as zero. + """ + print(densitymatrix_to_str(self.rho, self.nqubit, output, max_denominator=max_denominator, cutoff=cutoff)) + def apply_channel(self, channel: KrausChannel, qargs: Sequence[int]) -> None: """Apply a channel to a density matrix. diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..2cb30ede8 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevec_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -486,6 +487,33 @@ def to_dict( """ return self._to_dict_map(lambda x: x, encoding, rtol=rtol, atol=atol) + def draw( + self, + output: OutputFormat = OutputFormat.Unicode, + encoding: _ENCODING = "MSB", + *, + rtol: float = 0.0, + atol: float = 1e-8, + max_denominator: int = 1000, + ) -> None: + """Pretty-print the statevector. + + Parameters + ---------- + output : OutputFormat, default=OutputFormat.Unicode + Desired formatting style. + encoding : Literal["LSB", "MSB"], default="MSB" + Encoding for the basis kets. See :meth:`to_dict` for additional information. + rtol : float, default=0.0 + Relative tolerance for filtering zero amplitudes. + atol : float, default=1e-8 + Absolute tolerance for filtering zero amplitudes. + max_denominator : int, default=1000 + Maximum denominator when detecting simple fractions. + """ + sv_dict = self.to_dict(encoding=encoding, rtol=rtol, atol=atol) + print(statevec_to_str(sv_dict, output, max_denominator=max_denominator)) + def to_prob_dict( self, encoding: _ENCODING = "MSB", *, rtol: float = 0.0, atol: float = 1e-8 ) -> dict[str, np.object_ | np.float64]: diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..ab8e31fef 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -1,8 +1,10 @@ from __future__ import annotations +import math from typing import TYPE_CHECKING import networkx as nx +import numpy as np import pytest from numpy.random import PCG64, Generator @@ -13,8 +15,11 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, complex_to_str, densitymatrix_to_str, pattern_to_str, statevec_to_str from graphix.random_objects import rand_circuit +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec +from graphix.states import BasicStates from graphix.transpiler import Circuit if TYPE_CHECKING: @@ -202,3 +207,127 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +class TestComplexToStr: + """Tests for :func:`~graphix.pretty_print.complex_to_str`.""" + + def test_zero(self) -> None: + assert complex_to_str(0j, OutputFormat.Unicode) == "0" + assert complex_to_str(0j, OutputFormat.ASCII) == "0" + assert complex_to_str(0j, OutputFormat.LaTeX) == "0" + + def test_integer(self) -> None: + assert complex_to_str(1 + 0j, OutputFormat.Unicode) == "1" + assert complex_to_str(-1 + 0j, OutputFormat.Unicode) == "-1" + assert complex_to_str(2 + 0j, OutputFormat.Unicode) == "2" + + def test_simple_fractions(self) -> None: + assert complex_to_str(0.5 + 0j, OutputFormat.Unicode) == "1/2" + assert complex_to_str(0.25 + 0j, OutputFormat.Unicode) == "1/4" + assert complex_to_str(0.75 + 0j, OutputFormat.Unicode) == "3/4" + assert complex_to_str(-0.5 + 0j, OutputFormat.Unicode) == "-1/2" + + def test_square_roots(self) -> None: + sqrt2_inv = 1 / math.sqrt(2) + sqrt3_2 = math.sqrt(3) / 2 + assert complex_to_str(sqrt2_inv, OutputFormat.Unicode) == "1/√2", f"got {complex_to_str(sqrt2_inv, OutputFormat.Unicode)!r}" + assert complex_to_str(sqrt3_2, OutputFormat.Unicode) == "√3/2" + assert complex_to_str(-sqrt2_inv, OutputFormat.Unicode) == "-1/√2" + + def test_exponential(self) -> None: + e_ipi_3 = 0.5 + 0.8660254037844386j + e_ipi_4 = math.sqrt(2) / 2 + 1j * math.sqrt(2) / 2 + e_i2pi_3 = -0.5 + 0.8660254037844386j + assert complex_to_str(e_ipi_3, OutputFormat.Unicode) == "e^(iπ/3)", f"got {complex_to_str(e_ipi_3, OutputFormat.Unicode)!r}" + assert complex_to_str(e_ipi_4, OutputFormat.Unicode) == "e^(iπ/4)", f"got {complex_to_str(e_ipi_4, OutputFormat.Unicode)!r}" + assert complex_to_str(e_i2pi_3, OutputFormat.Unicode) == "e^(i2π/3)", f"got {complex_to_str(e_i2pi_3, OutputFormat.Unicode)!r}" + + def test_pure_imaginary(self) -> None: + assert complex_to_str(1j, OutputFormat.Unicode) == "i" + assert complex_to_str(-1j, OutputFormat.Unicode) == "-i" + assert complex_to_str(0.5j, OutputFormat.Unicode) == "1/2i" + assert complex_to_str(-0.5j, OutputFormat.Unicode) == "-1/2i" + + def test_general_complex(self) -> None: + assert complex_to_str(1 + 1j, OutputFormat.Unicode) == "1+i" + assert complex_to_str(1 - 1j, OutputFormat.Unicode) == "1-i" + assert complex_to_str(-1 + 1j, OutputFormat.Unicode) == "-1+i" + assert complex_to_str(0.5 + 0.5j, OutputFormat.Unicode) == "1/2+1/2i" + + def test_ascii_format(self) -> None: + assert complex_to_str(1j, OutputFormat.ASCII) == "i" + assert complex_to_str(0.7071067811865474 + 0j, OutputFormat.ASCII) == "1/sqrt(2)" + assert complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.ASCII) == "exp(i*pi/3)" + + def test_latex_format(self) -> None: + assert complex_to_str(1j, OutputFormat.LaTeX) == r"\mathrm{i}" + assert complex_to_str(0.5 + 0j, OutputFormat.LaTeX) == r"\frac{1}{2}" + assert complex_to_str(0.7071067811865474 + 0j, OutputFormat.LaTeX) == r"\frac{1}{\sqrt{2}}" + assert complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i}\pi/3}" + + +class TestStatevecToStr: + """Tests for :func:`~graphix.pretty_print.statevec_to_str`.""" + + def test_single_basis_state(self) -> None: + sv = Statevec(data=[BasicStates.ZERO]) + d = sv.to_dict() + assert statevec_to_str(d, OutputFormat.Unicode) == "|0⟩" + + def test_two_qubit_product(self) -> None: + sv = Statevec(data=[BasicStates.ZERO, BasicStates.ONE]) + d = sv.to_dict() + result = statevec_to_str(d, OutputFormat.Unicode) + assert result == "|01⟩" + + def test_plus_state(self) -> None: + sv = Statevec(data=[BasicStates.PLUS]) + d = sv.to_dict() + result = statevec_to_str(d, OutputFormat.Unicode) + assert result == "1/√2|0⟩ + 1/√2|1⟩" + + def test_latex_statevec(self) -> None: + sv = Statevec(data=[BasicStates.PLUS]) + d = sv.to_dict() + result = statevec_to_str(d, OutputFormat.LaTeX) + assert result == r"\frac{1}{\sqrt{2}}\lvert 0\rangle + \frac{1}{\sqrt{2}}\lvert 1\rangle" + + def test_empty_dict(self) -> None: + assert statevec_to_str({}, OutputFormat.Unicode) == "0" + + +class TestDensityMatrixToStr: + """Tests for :func:`~graphix.pretty_print.densitymatrix_to_str`.""" + + def test_pure_state_zero(self) -> None: + dm = DensityMatrix(data=[BasicStates.ZERO]) + result = densitymatrix_to_str(dm.rho, dm.nqubit, OutputFormat.Unicode) + assert result == "|0⟩⟨0|" + + def test_mixed_state(self) -> None: + """Check a 50/50 mixed state returns a sum of projectors.""" + rho = np.eye(2, dtype=np.complex128) / 2 + result = densitymatrix_to_str(rho, 1, OutputFormat.Unicode) + assert "1/2" in result + assert result.count("⟩⟨") == 2 + + +class TestStatevecDraw: + """Tests for :meth:`graphix.sim.statevec.Statevec.draw`.""" + + def test_draw(self, capsys: pytest.CaptureFixture[str]) -> None: + sv = Statevec(data=[BasicStates.PLUS]) + sv.draw() + captured = capsys.readouterr() + assert "1/√2|0⟩ + 1/√2|1⟩" in captured.out + + +class TestDensityMatrixDraw: + """Tests for :meth:`graphix.sim.density_matrix.DensityMatrix.draw`.""" + + def test_draw(self, capsys: pytest.CaptureFixture[str]) -> None: + dm = DensityMatrix(data=[BasicStates.ZERO]) + dm.draw() + captured = capsys.readouterr() + assert "|0⟩⟨0|" in captured.out