|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | import torch |
| 9 | +from torch.nn.functional import ( |
| 10 | + avg_pool2d, |
| 11 | + dropout, |
| 12 | + dropout2d, |
| 13 | + interpolate, |
| 14 | + max_pool2d, |
| 15 | + relu, |
| 16 | + sigmoid, |
| 17 | + tanh, |
| 18 | +) |
| 19 | + |
9 | 20 |
|
10 | 21 | from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate |
11 | 22 | from torch import tanh, relu, sigmoid |
12 | 23 |
|
13 | 24 |
|
14 | 25 | def complex_matmul(A, B): |
15 | | - ''' |
16 | | - Performs the matrix product between two complex matricess |
17 | | - ''' |
| 26 | + """ |
| 27 | + Performs the matrix product between two complex matrices |
| 28 | + """ |
18 | 29 |
|
19 | 30 | outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag) |
20 | 31 | outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real) |
21 | | - |
| 32 | + |
22 | 33 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) |
23 | 34 |
|
24 | 35 |
|
25 | | -def complex_avg_pool2d(input, *args, **kwargs): |
26 | | - ''' |
| 36 | +def complex_avg_pool2d(inp, *args, **kwargs): |
| 37 | + """ |
27 | 38 | Perform complex average pooling. |
28 | | - ''' |
29 | | - absolute_value_real = avg_pool2d(input.real, *args, **kwargs) |
30 | | - absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs) |
31 | | - |
32 | | - return absolute_value_real.type(torch.complex64)+1j*absolute_value_imag.type(torch.complex64) |
| 39 | + """ |
| 40 | + absolute_value_real = avg_pool2d(inp.real, *args, **kwargs) |
| 41 | + absolute_value_imag = avg_pool2d(inp.imag, *args, **kwargs) |
33 | 42 |
|
| 43 | + return absolute_value_real.type(torch.complex64) + 1j * absolute_value_imag.type( |
| 44 | + torch.complex64 |
| 45 | + ) |
34 | 46 |
|
35 | | -def complex_normalize(input): |
36 | | - ''' |
| 47 | + |
| 48 | +def complex_normalize(inp): |
| 49 | + """ |
37 | 50 | Perform complex normalization |
38 | | - ''' |
39 | | - real_value, imag_value = input.real, input.imag |
| 51 | + """ |
| 52 | + real_value, imag_value = inp.real, inp.imag |
40 | 53 | real_norm = (real_value - real_value.mean()) / real_value.std() |
41 | 54 | imag_norm = (imag_value - imag_value.mean()) / imag_value.std() |
42 | | - |
43 | | - return real_norm.type(torch.complex64) + 1j*imag_norm.type(torch.complex64) |
| 55 | + return real_norm.type(torch.complex64) + 1j * imag_norm.type(torch.complex64) |
44 | 56 |
|
45 | 57 |
|
46 | | -def complex_relu(input): |
47 | | - return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64) |
| 58 | +def complex_relu(inp): |
| 59 | + return relu(inp.real).type(torch.complex64) + 1j * relu(inp.imag).type( |
| 60 | + torch.complex64 |
| 61 | + ) |
48 | 62 |
|
49 | 63 |
|
50 | | -def complex_sigmoid(input): |
51 | | - return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64) |
| 64 | +def complex_sigmoid(inp): |
| 65 | + return sigmoid(inp.real).type(torch.complex64) + 1j * sigmoid(inp.imag).type( |
| 66 | + torch.complex64 |
| 67 | + ) |
52 | 68 |
|
53 | 69 |
|
54 | | -def complex_tanh(input): |
55 | | - return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64) |
| 70 | +def complex_tanh(inp): |
| 71 | + return tanh(inp.real).type(torch.complex64) + 1j * tanh(inp.imag).type( |
| 72 | + torch.complex64 |
| 73 | + ) |
56 | 74 |
|
57 | 75 |
|
58 | | -def complex_opposite(input): |
59 | | - return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64)) |
| 76 | +def complex_opposite(inp): |
| 77 | + return -inp.real.type(torch.complex64) + 1j * (-inp.imag.type(torch.complex64)) |
60 | 78 |
|
61 | 79 |
|
62 | | -def complex_stack(input, dim): |
63 | | - input_real = [x.real for x in input] |
64 | | - input_imag = [x.imag for x in input] |
65 | | - return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64) |
| 80 | +def complex_stack(inp, dim): |
| 81 | + inp_real = [x.real for x in inp] |
| 82 | + inp_imag = [x.imag for x in inp] |
| 83 | + return torch.stack(inp_real, dim).type(torch.complex64) + 1j * torch.stack( |
| 84 | + inp_imag, dim |
| 85 | + ).type(torch.complex64) |
66 | 86 |
|
67 | 87 |
|
68 | 88 | def _retrieve_elements_from_indices(tensor, indices): |
69 | 89 | flattened_tensor = tensor.flatten(start_dim=-2) |
70 | | - output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices) |
| 90 | + output = flattened_tensor.gather( |
| 91 | + dim=-1, index=indices.flatten(start_dim=-2) |
| 92 | + ).view_as(indices) |
71 | 93 | return output |
72 | 94 |
|
73 | 95 |
|
74 | | -def complex_upsample(input, size=None, scale_factor=None, mode='nearest', |
75 | | - align_corners=None, recompute_scale_factor=None): |
76 | | - ''' |
77 | | - Performs upsampling by separately interpolating the real and imaginary part and recombining |
78 | | - ''' |
79 | | - outp_real = interpolate(input.real, size=size, scale_factor=scale_factor, mode=mode, |
80 | | - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) |
81 | | - outp_imag = interpolate(input.imag, size=size, scale_factor=scale_factor, mode=mode, |
82 | | - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) |
83 | | - |
| 96 | +def complex_upsample( |
| 97 | + inp, |
| 98 | + size=None, |
| 99 | + scale_factor=None, |
| 100 | + mode="nearest", |
| 101 | + align_corners=None, |
| 102 | + recompute_scale_factor=None, |
| 103 | +): |
| 104 | + """ |
| 105 | + Performs upsampling by separately interpolating the real and imaginary part and recombining |
| 106 | + """ |
| 107 | + outp_real = interpolate( |
| 108 | + inp.real, |
| 109 | + size=size, |
| 110 | + scale_factor=scale_factor, |
| 111 | + mode=mode, |
| 112 | + align_corners=align_corners, |
| 113 | + recompute_scale_factor=recompute_scale_factor, |
| 114 | + ) |
| 115 | + outp_imag = interpolate( |
| 116 | + inp.imag, |
| 117 | + size=size, |
| 118 | + scale_factor=scale_factor, |
| 119 | + mode=mode, |
| 120 | + align_corners=align_corners, |
| 121 | + recompute_scale_factor=recompute_scale_factor, |
| 122 | + ) |
| 123 | + |
84 | 124 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) |
85 | 125 |
|
86 | 126 |
|
87 | | -def complex_upsample2(input, size=None, scale_factor=None, mode='nearest', |
88 | | - align_corners=None, recompute_scale_factor=None): |
89 | | - ''' |
90 | | - Performs upsampling by separately interpolating the amplitude and phase part and recombining |
91 | | - ''' |
92 | | - outp_abs = interpolate(input.abs(), size=size, scale_factor=scale_factor, mode=mode, |
93 | | - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) |
94 | | - angle = torch.atan2(input.imag,input.real) |
95 | | - outp_angle = interpolate(angle, size=size, scale_factor=scale_factor, mode=mode, |
96 | | - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) |
97 | | - |
98 | | - return outp_abs \ |
99 | | - * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) |
100 | | - |
101 | | - |
102 | | -def complex_max_pool2d(input,kernel_size, stride=None, padding=0, |
103 | | - dilation=1, ceil_mode=False, return_indices=False): |
104 | | - ''' |
| 127 | +def complex_upsample2( |
| 128 | + inp, |
| 129 | + size=None, |
| 130 | + scale_factor=None, |
| 131 | + mode="nearest", |
| 132 | + align_corners=None, |
| 133 | + recompute_scale_factor=None, |
| 134 | +): |
| 135 | + """ |
| 136 | + Performs upsampling by separately interpolating the amplitude and phase part and recombining |
| 137 | + """ |
| 138 | + outp_abs = interpolate( |
| 139 | + inp.abs(), |
| 140 | + size=size, |
| 141 | + scale_factor=scale_factor, |
| 142 | + mode=mode, |
| 143 | + align_corners=align_corners, |
| 144 | + recompute_scale_factor=recompute_scale_factor, |
| 145 | + ) |
| 146 | + angle = torch.atan2(inp.imag, inp.real) |
| 147 | + outp_angle = interpolate( |
| 148 | + angle, |
| 149 | + size=size, |
| 150 | + scale_factor=scale_factor, |
| 151 | + mode=mode, |
| 152 | + align_corners=align_corners, |
| 153 | + recompute_scale_factor=recompute_scale_factor, |
| 154 | + ) |
| 155 | + |
| 156 | + return outp_abs * ( |
| 157 | + torch.cos(outp_angle).type(torch.complex64) |
| 158 | + + 1j * torch.sin(outp_angle).type(torch.complex64) |
| 159 | + ) |
| 160 | + |
| 161 | + |
| 162 | +def complex_max_pool2d( |
| 163 | + inp, |
| 164 | + kernel_size, |
| 165 | + stride=None, |
| 166 | + padding=0, |
| 167 | + dilation=1, |
| 168 | + ceil_mode=False, |
| 169 | + return_indices=False, |
| 170 | +): |
| 171 | + """ |
105 | 172 | Perform complex max pooling by selecting on the absolute value on the complex values. |
106 | | - ''' |
107 | | - absolute_value, indices = max_pool2d( |
108 | | - input.abs(), |
109 | | - kernel_size = kernel_size, |
110 | | - stride = stride, |
111 | | - padding = padding, |
112 | | - dilation = dilation, |
113 | | - ceil_mode = ceil_mode, |
114 | | - return_indices = True |
115 | | - ) |
| 173 | + """ |
| 174 | + absolute_value, indices = max_pool2d( |
| 175 | + inp.abs(), |
| 176 | + kernel_size=kernel_size, |
| 177 | + stride=stride, |
| 178 | + padding=padding, |
| 179 | + dilation=dilation, |
| 180 | + ceil_mode=ceil_mode, |
| 181 | + return_indices=True, |
| 182 | + ) |
116 | 183 | # performs the selection on the absolute values |
117 | 184 | absolute_value = absolute_value.type(torch.complex64) |
118 | | - # retrieve the corresonding phase value using the indices |
| 185 | + # retrieve the corresponding phase value using the indices |
119 | 186 | # unfortunately, the derivative for 'angle' is not implemented |
120 | | - angle = torch.atan2(input.imag,input.real) |
| 187 | + angle = torch.atan2(inp.imag, inp.real) |
121 | 188 | # get only the phase values selected by max pool |
122 | 189 | angle = _retrieve_elements_from_indices(angle, indices) |
123 | | - return absolute_value \ |
124 | | - * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) |
| 190 | + return absolute_value * ( |
| 191 | + torch.cos(angle).type(torch.complex64) |
| 192 | + + 1j * torch.sin(angle).type(torch.complex64) |
| 193 | + ) |
125 | 194 |
|
126 | 195 |
|
127 | | -def complex_dropout(input, p=0.5, training=True): |
128 | | - # need to have the same dropout mask for real and imaginary part, |
| 196 | +def complex_dropout(inp, p=0.5, training=True): |
| 197 | + # need to have the same dropout mask for real and imaginary part, |
129 | 198 | # this not a clean solution! |
130 | | - #mask = torch.ones_like(input).type(torch.float32) |
131 | | - mask = torch.ones(*input.shape, dtype = torch.float32) |
132 | | - mask = dropout(mask, p, training)*1/(1-p) |
133 | | - mask.type(input.dtype) |
134 | | - return mask*input |
| 199 | + mask = torch.ones(*inp.shape, dtype=torch.float32, device=inp.device) |
| 200 | + mask = dropout(mask, p, training) * 1 / (1 - p) |
| 201 | + mask.type(inp.dtype) |
| 202 | + return mask * inp |
135 | 203 |
|
136 | 204 |
|
137 | | -def complex_dropout2d(input, p=0.5, training=True): |
138 | | - # need to have the same dropout mask for real and imaginary part, |
| 205 | +def complex_dropout2d(inp, p=0.5, training=True): |
| 206 | + # need to have the same dropout mask for real and imaginary part, |
139 | 207 | # this not a clean solution! |
140 | | - mask = torch.ones(*input.shape, dtype = torch.float32) |
141 | | - mask = dropout2d(mask, p, training)*1/(1-p) |
142 | | - mask.type(input.dtype) |
143 | | - return mask*input |
| 208 | + mask = torch.ones(*inp.shape, dtype=torch.float32, device=inp.device) |
| 209 | + mask = dropout2d(mask, p, training) * 1 / (1 - p) |
| 210 | + mask.type(inp.dtype) |
| 211 | + return mask * inp |
0 commit comments