Hi, I noticed that you have a custom matmul defined:

```python
def complex_matmul(A, B):
    '''
    Performs the matrix product between two complex matrices
    '''

    outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
    outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)

    return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
```

as well as custom tanh and negation functions:

```python
def complex_tanh(input):
    return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64)


def complex_opposite(input):
    return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64))
```

These are actually unnecessary, since `torch.matmul`, `torch.tanh`, and unary negation have supported complex tensors for the last couple of PyTorch releases. The native versions would be much faster too: for matmul, for example, we call into a single complex BLAS operation instead of four real matrix products plus dtype conversions. One caveat: `complex_tanh` applies `tanh` to the real and imaginary parts separately, which is a different function from `torch.tanh` on a complex input, so if that split behavior is intended it can be written with native ops as `torch.complex(torch.tanh(input.real), torch.tanh(input.imag))`.
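A minimal sketch of the native replacements, assuming a PyTorch release with complex support in `torch.matmul`, `torch.tanh`, and negation. It checks that `torch.matmul` matches the four-real-matmul formula used by `complex_matmul` (up to float32 rounding), that `-x` matches the math of `complex_opposite`, and that `torch.tanh` on a complex tensor is *not* the split activation computed by `complex_tanh`. `split_tanh` is a hypothetical name for a native-ops version of that split activation, not an existing PyTorch or complexPyTorch function.

```python
import torch


def split_tanh(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: tanh applied separately to the real and imaginary
    # parts (same math as complex_tanh), built directly as complex64.
    return torch.complex(torch.tanh(x.real), torch.tanh(x.imag))


torch.manual_seed(0)
A = torch.randn(4, 3, dtype=torch.complex64)
B = torch.randn(3, 5, dtype=torch.complex64)
x = torch.randn(6, dtype=torch.complex64)

# Native complex matmul: a single complex GEMM.
native_mm = torch.matmul(A, B)

# The same product assembled from four real matmuls, as complex_matmul does.
manual_mm = torch.complex(
    A.real @ B.real - A.imag @ B.imag,
    A.real @ B.imag + A.imag @ B.real,
)
print(torch.allclose(native_mm, manual_mm, atol=1e-5))

# Native negation gives the same result as complex_opposite's math.
print(torch.equal(-x, torch.complex(-x.real, -x.imag)))

# torch.tanh on complex input is the holomorphic tanh, not the split version,
# so these do not match.
print(torch.allclose(torch.tanh(x), split_tanh(x)))
```

Building the split result with `torch.complex` avoids casting both real halves to complex64 and then doing an extra multiply by `1j` and an add, which is what the original `.type(torch.complex64)` pattern does.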
Referenced code: `complexPyTorch/complexPyTorch/complexFunctions.py`, lines 11 to 19 (`complex_matmul`) and lines 52 to 56 (`complex_tanh`, `complex_opposite`) at commit a4e752c.