diff --git a/src/deepquantum/photonic/circuit.py b/src/deepquantum/photonic/circuit.py index f148bf0d..53af7540 100644 --- a/src/deepquantum/photonic/circuit.py +++ b/src/deepquantum/photonic/circuit.py @@ -1168,7 +1168,8 @@ def _get_prob_gaussian_base( sub_mat = sub_gamma else: sub_mat[torch.arange(len(sub_gamma)), torch.arange(len(sub_gamma))] = sub_gamma - haf = abs(hafnian(sub_mat, loop=loop)) ** 2 if purity else hafnian(sub_mat, loop=loop) + temp_haf = hafnian(sub_mat, loop=loop) + haf = temp_haf.real.square() + temp_haf.imag.square() if purity else temp_haf prob = p_vac * haf / product_factorial(final_state).to(haf.device, haf.dtype) elif detector == 'threshold': final_state_double = torch.cat([final_state, final_state]) diff --git a/src/deepquantum/photonic/hafnian_.py b/src/deepquantum/photonic/hafnian_.py index 19b560d6..5b72c79a 100644 --- a/src/deepquantum/photonic/hafnian_.py +++ b/src/deepquantum/photonic/hafnian_.py @@ -49,7 +49,9 @@ def get_submat_haf(a: torch.Tensor, z: torch.Tensor) -> torch.Tensor: return submat -def poly_lambda(submat: torch.Tensor, int_partition: list, power: int, loop: bool = False) -> torch.Tensor: +def poly_lambda( + submat: torch.Tensor, int_partition: list, power: int, loop: bool = False, threshold: float = 1e-30 +) -> torch.Tensor: """Get the coefficient of the polynomial. See https://arxiv.org/abs/1805.12498 Eq.(3.26) (noting that Eq.(3.26) contains a typo) and @@ -85,7 +87,8 @@ def poly_lambda(submat: torch.Tensor, int_partition: list, power: int, loop: boo poly_list = trace_list[orders] / (2 * orders) if loop: poly_list += diag_term[orders - 1] - poly_prod = poly_list.prod() + mask = abs(poly_list) > threshold # numerical stability for gradient + poly_prod = (mask * poly_list).prod() coeff += ncount / factorial(len(orders)) * poly_prod return coeff