Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/deepquantum/photonic/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,8 @@ def _get_prob_gaussian_base(
sub_mat = sub_gamma
else:
sub_mat[torch.arange(len(sub_gamma)), torch.arange(len(sub_gamma))] = sub_gamma
haf = abs(hafnian(sub_mat, loop=loop)) ** 2 if purity else hafnian(sub_mat, loop=loop)
temp_haf = hafnian(sub_mat, loop=loop)
haf = temp_haf.real.square() + temp_haf.imag.square() if purity else temp_haf
prob = p_vac * haf / product_factorial(final_state).to(haf.device, haf.dtype)
elif detector == 'threshold':
final_state_double = torch.cat([final_state, final_state])
Expand Down
7 changes: 5 additions & 2 deletions src/deepquantum/photonic/hafnian_.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def get_submat_haf(a: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
return submat


def poly_lambda(submat: torch.Tensor, int_partition: list, power: int, loop: bool = False) -> torch.Tensor:
def poly_lambda(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不管是对prod的单个元素还是最终乘积做阈值判断都不能正确解决问题(比如阈值以下的元素被乘以0,导致其梯度强制变为0了)。感觉这里最方便的处理还是先把submat变为cdouble,然后return的时候变回原始类型

submat: torch.Tensor, int_partition: list, power: int, loop: bool = False, threshold: float = 1e-30
) -> torch.Tensor:
"""Get the coefficient of the polynomial.

See https://arxiv.org/abs/1805.12498 Eq.(3.26) (noting that Eq.(3.26) contains a typo) and
Expand Down Expand Up @@ -85,7 +87,8 @@ def poly_lambda(submat: torch.Tensor, int_partition: list, power: int, loop: boo
poly_list = trace_list[orders] / (2 * orders)
if loop:
poly_list += diag_term[orders - 1]
poly_prod = poly_list.prod()
mask = abs(poly_list) > threshold # numerical stability for gradient
poly_prod = (mask * poly_list).prod()
coeff += ncount / factorial(len(orders)) * poly_prod
return coeff

Expand Down