diff --git a/src/deepquantum/utils.py b/src/deepquantum/utils.py index de4c1239..0dfd6871 100644 --- a/src/deepquantum/utils.py +++ b/src/deepquantum/utils.py @@ -44,6 +44,8 @@ def wrapped_function(*args, **kwargs): def apply_complex_fix(fn: Any, tensors_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Apply the function to the tensors in the dictionary and convert the result to complex dtype.""" + if len(tensors_dict) == 0: + return tensors_dict first_tensor = next(iter(tensors_dict.values())) probe = fn(torch.empty(0, dtype=first_tensor.real.dtype, device=first_tensor.device)) target_dtype = dq.dtype_map.get(probe.dtype, probe.dtype) diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 938b9520..0e979bff 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -3,6 +3,17 @@ import deepquantum as dq +def test_qubit_circuit_to_with_barrier(): + cir = dq.QubitCircuit(2) + cir.h(0) + cir.barrier() + cir.rx(1, 0.1) + + cir.to(torch.double) + state = cir() + assert state.dtype == torch.cdouble + + def test_qubit_mps(): nqubit = 3 cir = dq.QubitCircuit(nqubit, mps=True, chi=64)