Skip to content

Commit f7b1e70

Browse files
Merge branch 'master' into master
2 parents 48de6d5 + 60bb6c5 commit f7b1e70

3 files changed

Lines changed: 446 additions & 270 deletions

File tree

complexPyTorch/complexFunctions.py

Lines changed: 153 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -6,138 +6,206 @@
66
"""
77

88
import torch
9+
from torch.nn.functional import (
10+
avg_pool2d,
11+
dropout,
12+
dropout2d,
13+
interpolate,
14+
max_pool2d,
15+
relu,
16+
sigmoid,
17+
tanh,
18+
)
19+
920

1021
from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate
1122
from torch import tanh, relu, sigmoid
1223

1324

1425
def complex_matmul(A, B):
15-
'''
16-
Performs the matrix product between two complex matricess
17-
'''
26+
"""
27+
Performs the matrix product between two complex matrices
28+
"""
1829

1930
outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag)
2031
outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real)
21-
32+
2233
return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
2334

2435

25-
def complex_avg_pool2d(input, *args, **kwargs):
26-
'''
36+
def complex_avg_pool2d(inp, *args, **kwargs):
37+
"""
2738
Perform complex average pooling.
28-
'''
29-
absolute_value_real = avg_pool2d(input.real, *args, **kwargs)
30-
absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs)
31-
32-
return absolute_value_real.type(torch.complex64)+1j*absolute_value_imag.type(torch.complex64)
39+
"""
40+
absolute_value_real = avg_pool2d(inp.real, *args, **kwargs)
41+
absolute_value_imag = avg_pool2d(inp.imag, *args, **kwargs)
3342

43+
return absolute_value_real.type(torch.complex64) + 1j * absolute_value_imag.type(
44+
torch.complex64
45+
)
3446

35-
def complex_normalize(input):
36-
'''
47+
48+
def complex_normalize(inp):
49+
"""
3750
Perform complex normalization
38-
'''
39-
real_value, imag_value = input.real, input.imag
51+
"""
52+
real_value, imag_value = inp.real, inp.imag
4053
real_norm = (real_value - real_value.mean()) / real_value.std()
4154
imag_norm = (imag_value - imag_value.mean()) / imag_value.std()
42-
43-
return real_norm.type(torch.complex64) + 1j*imag_norm.type(torch.complex64)
55+
return real_norm.type(torch.complex64) + 1j * imag_norm.type(torch.complex64)
4456

4557

46-
def complex_relu(input):
47-
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
58+
def complex_relu(inp):
59+
return relu(inp.real).type(torch.complex64) + 1j * relu(inp.imag).type(
60+
torch.complex64
61+
)
4862

4963

50-
def complex_sigmoid(input):
51-
return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64)
64+
def complex_sigmoid(inp):
65+
return sigmoid(inp.real).type(torch.complex64) + 1j * sigmoid(inp.imag).type(
66+
torch.complex64
67+
)
5268

5369

54-
def complex_tanh(input):
55-
return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64)
70+
def complex_tanh(inp):
71+
return tanh(inp.real).type(torch.complex64) + 1j * tanh(inp.imag).type(
72+
torch.complex64
73+
)
5674

5775

58-
def complex_opposite(input):
59-
return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64))
76+
def complex_opposite(inp):
77+
return -inp.real.type(torch.complex64) + 1j * (-inp.imag.type(torch.complex64))
6078

6179

62-
def complex_stack(input, dim):
63-
input_real = [x.real for x in input]
64-
input_imag = [x.imag for x in input]
65-
return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64)
80+
def complex_stack(inp, dim):
81+
inp_real = [x.real for x in inp]
82+
inp_imag = [x.imag for x in inp]
83+
return torch.stack(inp_real, dim).type(torch.complex64) + 1j * torch.stack(
84+
inp_imag, dim
85+
).type(torch.complex64)
6686

6787

6888
def _retrieve_elements_from_indices(tensor, indices):
6989
flattened_tensor = tensor.flatten(start_dim=-2)
70-
output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices)
90+
output = flattened_tensor.gather(
91+
dim=-1, index=indices.flatten(start_dim=-2)
92+
).view_as(indices)
7193
return output
7294

7395

74-
def complex_upsample(input, size=None, scale_factor=None, mode='nearest',
75-
align_corners=None, recompute_scale_factor=None):
76-
'''
77-
Performs upsampling by separately interpolating the real and imaginary part and recombining
78-
'''
79-
outp_real = interpolate(input.real, size=size, scale_factor=scale_factor, mode=mode,
80-
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
81-
outp_imag = interpolate(input.imag, size=size, scale_factor=scale_factor, mode=mode,
82-
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
83-
96+
def complex_upsample(
97+
inp,
98+
size=None,
99+
scale_factor=None,
100+
mode="nearest",
101+
align_corners=None,
102+
recompute_scale_factor=None,
103+
):
104+
"""
105+
Performs upsampling by separately interpolating the real and imaginary part and recombining
106+
"""
107+
outp_real = interpolate(
108+
inp.real,
109+
size=size,
110+
scale_factor=scale_factor,
111+
mode=mode,
112+
align_corners=align_corners,
113+
recompute_scale_factor=recompute_scale_factor,
114+
)
115+
outp_imag = interpolate(
116+
inp.imag,
117+
size=size,
118+
scale_factor=scale_factor,
119+
mode=mode,
120+
align_corners=align_corners,
121+
recompute_scale_factor=recompute_scale_factor,
122+
)
123+
84124
return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64)
85125

86126

87-
def complex_upsample2(input, size=None, scale_factor=None, mode='nearest',
88-
align_corners=None, recompute_scale_factor=None):
89-
'''
90-
Performs upsampling by separately interpolating the amplitude and phase part and recombining
91-
'''
92-
outp_abs = interpolate(input.abs(), size=size, scale_factor=scale_factor, mode=mode,
93-
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
94-
angle = torch.atan2(input.imag,input.real)
95-
outp_angle = interpolate(angle, size=size, scale_factor=scale_factor, mode=mode,
96-
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
97-
98-
return outp_abs \
99-
* (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64))
100-
101-
102-
def complex_max_pool2d(input,kernel_size, stride=None, padding=0,
103-
dilation=1, ceil_mode=False, return_indices=False):
104-
'''
127+
def complex_upsample2(
128+
inp,
129+
size=None,
130+
scale_factor=None,
131+
mode="nearest",
132+
align_corners=None,
133+
recompute_scale_factor=None,
134+
):
135+
"""
136+
Performs upsampling by separately interpolating the amplitude and phase part and recombining
137+
"""
138+
outp_abs = interpolate(
139+
inp.abs(),
140+
size=size,
141+
scale_factor=scale_factor,
142+
mode=mode,
143+
align_corners=align_corners,
144+
recompute_scale_factor=recompute_scale_factor,
145+
)
146+
angle = torch.atan2(inp.imag, inp.real)
147+
outp_angle = interpolate(
148+
angle,
149+
size=size,
150+
scale_factor=scale_factor,
151+
mode=mode,
152+
align_corners=align_corners,
153+
recompute_scale_factor=recompute_scale_factor,
154+
)
155+
156+
return outp_abs * (
157+
torch.cos(outp_angle).type(torch.complex64)
158+
+ 1j * torch.sin(outp_angle).type(torch.complex64)
159+
)
160+
161+
162+
def complex_max_pool2d(
163+
inp,
164+
kernel_size,
165+
stride=None,
166+
padding=0,
167+
dilation=1,
168+
ceil_mode=False,
169+
return_indices=False,
170+
):
171+
"""
105172
Perform complex max pooling by selecting on the absolute value on the complex values.
106-
'''
107-
absolute_value, indices = max_pool2d(
108-
input.abs(),
109-
kernel_size = kernel_size,
110-
stride = stride,
111-
padding = padding,
112-
dilation = dilation,
113-
ceil_mode = ceil_mode,
114-
return_indices = True
115-
)
173+
"""
174+
absolute_value, indices = max_pool2d(
175+
inp.abs(),
176+
kernel_size=kernel_size,
177+
stride=stride,
178+
padding=padding,
179+
dilation=dilation,
180+
ceil_mode=ceil_mode,
181+
return_indices=True,
182+
)
116183
# performs the selection on the absolute values
117184
absolute_value = absolute_value.type(torch.complex64)
118-
# retrieve the corresonding phase value using the indices
185+
# retrieve the corresponding phase value using the indices
119186
# unfortunately, the derivative for 'angle' is not implemented
120-
angle = torch.atan2(input.imag,input.real)
187+
angle = torch.atan2(inp.imag, inp.real)
121188
# get only the phase values selected by max pool
122189
angle = _retrieve_elements_from_indices(angle, indices)
123-
return absolute_value \
124-
* (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64))
190+
return absolute_value * (
191+
torch.cos(angle).type(torch.complex64)
192+
+ 1j * torch.sin(angle).type(torch.complex64)
193+
)
125194

126195

127-
def complex_dropout(input, p=0.5, training=True):
128-
# need to have the same dropout mask for real and imaginary part,
196+
def complex_dropout(inp, p=0.5, training=True):
197+
# need to have the same dropout mask for real and imaginary part,
129198
# this not a clean solution!
130-
#mask = torch.ones_like(input).type(torch.float32)
131-
mask = torch.ones(*input.shape, dtype = torch.float32)
132-
mask = dropout(mask, p, training)*1/(1-p)
133-
mask.type(input.dtype)
134-
return mask*input
199+
mask = torch.ones(*inp.shape, dtype=torch.float32, device=inp.device)
200+
mask = dropout(mask, p, training) * 1 / (1 - p)
201+
mask.type(inp.dtype)
202+
return mask * inp
135203

136204

137-
def complex_dropout2d(input, p=0.5, training=True):
138-
# need to have the same dropout mask for real and imaginary part,
205+
def complex_dropout2d(inp, p=0.5, training=True):
206+
# need to have the same dropout mask for real and imaginary part,
139207
# this not a clean solution!
140-
mask = torch.ones(*input.shape, dtype = torch.float32)
141-
mask = dropout2d(mask, p, training)*1/(1-p)
142-
mask.type(input.dtype)
143-
return mask*input
208+
mask = torch.ones(*inp.shape, dtype=torch.float32, device=inp.device)
209+
mask = dropout2d(mask, p, training) * 1 / (1 - p)
210+
mask.type(inp.dtype)
211+
return mask * inp

0 commit comments

Comments
 (0)