Skip to content

Commit 2044cb0

Browse files
Merge pull request #21 from egrinstein/master
Fix conflicts, add RNN modules, PReLU and formatting
2 parents 60bb6c5 + f7b1e70 commit 2044cb0

2 files changed

Lines changed: 191 additions & 68 deletions

File tree

complexPyTorch/complexFunctions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
)
1919

2020

21+
from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate
22+
from torch import tanh, relu, sigmoid
23+
24+
2125
def complex_matmul(A, B):
2226
"""
2327
Performs the matrix product between two complex matrices
@@ -48,7 +52,6 @@ def complex_normalize(inp):
4852
real_value, imag_value = inp.real, inp.imag
4953
real_norm = (real_value - real_value.mean()) / real_value.std()
5054
imag_norm = (imag_value - imag_value.mean()) / imag_value.std()
51-
5255
return real_norm.type(torch.complex64) + 1j * imag_norm.type(torch.complex64)
5356

5457

complexPyTorch/complexLayers.py

Lines changed: 187 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,43 @@
1111
from typing import Optional
1212

1313
import torch
14-
from complexFunctions import (complex_avg_pool2d, complex_dropout,
15-
complex_dropout2d, complex_max_pool2d,
16-
complex_opposite, complex_relu, complex_sigmoid,
17-
complex_tanh)
18-
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, ConvTranspose2d,
19-
Linear, Module, Parameter, init)
14+
from torch.nn import (
15+
Module, Parameter, init,
16+
Conv2d, ConvTranspose2d, Linear, LSTM, GRU,
17+
BatchNorm1d, BatchNorm2d,
18+
PReLU
19+
)
20+
21+
from .complexFunctions import (
22+
complex_relu,
23+
complex_tanh,
24+
complex_sigmoid,
25+
complex_max_pool2d,
26+
complex_avg_pool2d,
27+
complex_dropout,
28+
complex_dropout2d,
29+
complex_opposite,
30+
)
31+
32+
33+
def apply_complex(fr, fi, input, dtype=torch.complex64):
    """Apply a complex linear operator built from two real modules.

    For complex weights W = Wr + i*Wi (realized by the real-valued modules
    ``fr`` and ``fi``) and complex input x = a + i*b:
        W x = (fr(a) - fi(b)) + i*(fr(b) + fi(a))
    The two parts are cast to ``dtype`` before being combined.
    """
    real_part = fr(input.real) - fi(input.imag)
    imag_part = fr(input.imag) + fi(input.real)
    return real_part.type(dtype) + 1j * imag_part.type(dtype)
2036

2137

22-
def apply_complex(fr, fi, inp, dtype=torch.complex64):
23-
return (fr(inp.real) - fi(inp.imag)).type(dtype) + 1j * (
24-
fr(inp.imag) + fi(inp.real)
25-
).type(dtype)
26-
2738
class ComplexDropout(Module):
    """Dropout for complex tensors, delegating to ``complex_dropout``.

    Active only in training mode; in eval mode the input is returned
    unchanged, matching ``torch.nn.Dropout`` semantics.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # probability of an element being zeroed

    def forward(self, input):
        if self.training:
            return complex_dropout(input, self.p)
        # BUG FIX: the original returned the undefined name `inp` here
        # (the parameter had been renamed to `input`), raising NameError
        # whenever the module was used in eval mode.
        return input
3748

3849

50+
3951
class ComplexDropout2d(Module):
4052
def __init__(self, p=0.5):
4153
super(ComplexDropout2d, self).__init__()
@@ -89,8 +101,8 @@ def __init__(self,kernel_size, stride= None, padding = 0,
89101
self.count_include_pad = count_include_pad
90102
self.divisor_override = divisor_override
91103

92-
def forward(self,input):
93-
return complex_avg_pool2d(input,kernel_size = self.kernel_size,
104+
def forward(self,inp):
105+
return complex_avg_pool2d(inp,kernel_size = self.kernel_size,
94106
stride = self.stride, padding = self.padding,
95107
ceil_mode = self.ceil_mode, count_include_pad = self.count_include_pad,
96108
divisor_override = self.divisor_override)
@@ -106,6 +118,16 @@ class ComplexSigmoid(Module):
106118
@staticmethod
107119
def forward(inp):
108120
return complex_sigmoid(inp)
121+
122+
class ComplexPReLU(Module):
    """PReLU applied independently to the real and imaginary parts.

    Each part gets its own learnable negative-slope parameter.
    """

    def __init__(self):
        super().__init__()
        self.r_prelu = PReLU()  # learnable slope for the real part
        self.i_prelu = PReLU()  # learnable slope for the imaginary part

    # BUG FIX: `forward` was decorated @staticmethod while still taking
    # `self`, so calling the module bound the input to `self` and raised
    # TypeError. It must be an ordinary instance method.
    def forward(self, inp):
        return self.r_prelu(inp.real) + 1j * self.i_prelu(inp.imag)
109131

110132

111133
class ComplexTanh(Module):
@@ -129,32 +151,12 @@ def __init__(
129151
padding_mode="zeros",
130152
):
131153

132-
super(ComplexConvTranspose2d, self).__init__()
154+
super().__init__()
133155

134-
self.conv_tran_r = ConvTranspose2d(
135-
in_channels,
136-
out_channels,
137-
kernel_size,
138-
stride,
139-
padding,
140-
output_padding,
141-
groups,
142-
bias,
143-
dilation,
144-
padding_mode,
145-
)
146-
self.conv_tran_i = ConvTranspose2d(
147-
in_channels,
148-
out_channels,
149-
kernel_size,
150-
stride,
151-
padding,
152-
output_padding,
153-
groups,
154-
bias,
155-
dilation,
156-
padding_mode,
157-
)
156+
self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
157+
output_padding, groups, bias, dilation, padding_mode)
158+
self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
159+
output_padding, groups, bias, dilation, padding_mode)
158160

159161
def forward(self, inp):
160162
return apply_complex(self.conv_tran_r, self.conv_tran_i, inp)
@@ -200,7 +202,7 @@ def forward(self, inp):
200202

201203
class ComplexLinear(Module):
202204
def __init__(self, in_features, out_features):
203-
super(ComplexLinear, self).__init__()
205+
super().__init__()
204206
self.fc_r = Linear(in_features, out_features)
205207
self.fc_i = Linear(in_features, out_features)
206208

@@ -315,7 +317,7 @@ def reset_parameters(self):
315317
init.constant_(self.weight[:, :2], 1.4142135623730951)
316318
init.zeros_(self.weight[:, 2])
317319
init.zeros_(self.bias)
318-
320+
319321

320322
class ComplexBatchNorm2d(_ComplexBatchNorm):
321323
def forward(self, inp):
@@ -325,7 +327,8 @@ def forward(self, inp):
325327
if self.num_batches_tracked is not None:
326328
self.num_batches_tracked += 1
327329
if self.momentum is None: # use cumulative moving average
328-
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
330+
exponential_average_factor = 1.0 / \
331+
float(self.num_batches_tracked)
329332
else: # use exponential moving average
330333
exponential_average_factor = self.momentum
331334

@@ -405,7 +408,6 @@ def forward(self, inp):
405408
).type(
406409
torch.complex64
407410
)
408-
409411
return inp
410412

411413

@@ -418,7 +420,8 @@ def forward(self, inp):
418420
if self.num_batches_tracked is not None:
419421
self.num_batches_tracked += 1
420422
if self.momentum is None: # use cumulative moving average
421-
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
423+
exponential_average_factor = 1.0 / \
424+
float(self.num_batches_tracked)
422425
else: # use exponential moving average
423426
exponential_average_factor = self.momentum
424427

@@ -500,26 +503,32 @@ def forward(self, inp):
500503
return inp
501504

502505

506+
503507
class ComplexGRUCell(Module):
504508
"""
505509
A GRU cell for complex-valued inputs
506510
"""
507-
508-
def __init__(self, inp_length=10, hidden_length=20):
509-
super(ComplexGRUCell, self).__init__()
510-
self.inp_length = inp_length
511+
def __init__(self, input_length, hidden_length):
512+
super().__init__()
513+
self.input_length = input_length
511514
self.hidden_length = hidden_length
512515

513516
# reset gate components
514-
self.linear_reset_w1 = ComplexLinear(self.inp_length, self.hidden_length)
515-
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
517+
self.linear_reset_w1 = ComplexLinear(
518+
self.input_length, self.hidden_length)
519+
self.linear_reset_r1 = ComplexLinear(
520+
self.hidden_length, self.hidden_length)
516521

517-
self.linear_reset_w2 = ComplexLinear(self.inp_length, self.hidden_length)
518-
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
522+
self.linear_reset_w2 = ComplexLinear(
523+
self.input_length, self.hidden_length)
524+
self.linear_reset_r2 = ComplexLinear(
525+
self.hidden_length, self.hidden_length)
519526

520527
# update gate components
521-
self.linear_gate_w3 = ComplexLinear(self.inp_length, self.hidden_length)
522-
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
528+
self.linear_gate_w3 = ComplexLinear(
529+
self.input_length, self.hidden_length)
530+
self.linear_gate_r3 = ComplexLinear(
531+
self.hidden_length, self.hidden_length)
523532

524533
self.activation_gate = ComplexSigmoid()
525534
self.activation_candidate = ComplexTanh()
@@ -555,30 +564,35 @@ def forward(self, x, h):
555564

556565
# Equation 4: the new hidden state
557566
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
558-
559567
return h_new
560568

561569

562570
class ComplexBNGRUCell(Module):
563571
"""
564572
A BN-GRU cell for complex-valued inputs
565573
"""
566-
567-
def __init__(self, inp_length=10, hidden_length=20):
568-
super(ComplexBNGRUCell, self).__init__()
569-
self.inp_length = inp_length
574+
575+
def __init__(self, input_length=10, hidden_length=20):
576+
super().__init__()
577+
self.input_length = input_length
570578
self.hidden_length = hidden_length
571579

572580
# reset gate components
573-
self.linear_reset_w1 = ComplexLinear(self.inp_length, self.hidden_length)
574-
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
581+
self.linear_reset_w1 = ComplexLinear(
582+
self.input_length, self.hidden_length)
583+
self.linear_reset_r1 = ComplexLinear(
584+
self.hidden_length, self.hidden_length)
575585

576-
self.linear_reset_w2 = ComplexLinear(self.inp_length, self.hidden_length)
577-
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
586+
self.linear_reset_w2 = ComplexLinear(
587+
self.input_length, self.hidden_length)
588+
self.linear_reset_r2 = ComplexLinear(
589+
self.hidden_length, self.hidden_length)
578590

579591
# update gate components
580-
self.linear_gate_w3 = ComplexLinear(self.inp_length, self.hidden_length)
581-
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
592+
self.linear_gate_w3 = ComplexLinear(
593+
self.input_length, self.hidden_length)
594+
self.linear_gate_r3 = ComplexLinear(
595+
self.hidden_length, self.hidden_length)
582596

583597
self.activation_gate = ComplexSigmoid()
584598
self.activation_candidate = ComplexTanh()
@@ -615,6 +629,112 @@ def forward(self, x, h):
615629
n = self.update_component(x, h, r)
616630

617631
# Equation 4: the new hidden state
632+
633+
634+
class ComplexGRU(Module):
    """Complex-valued GRU built from two real-valued GRUs.

    With complex weights W = W_re + i*W_im (realized by ``gru_re`` and
    ``gru_im``) and complex input x = a + i*b, the sequence output is
    combined as (re(a) - im(b)) + i*(re(b) + im(a)), where re/im denote
    the real GRUs applied to real tensors.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        super().__init__()

        # One real GRU realizing the real weight part, one the imaginary part.
        self.gru_re = GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, bias=bias,
                          batch_first=batch_first, dropout=dropout,
                          bidirectional=bidirectional)
        self.gru_im = GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, bias=bias,
                          batch_first=batch_first, dropout=dropout,
                          bidirectional=bidirectional)

    def forward(self, x):
        """Return ``(output, state)`` for complex input ``x``.

        BUG FIX: the original class defined ``forward`` twice, so the first
        definition (and the ``_forward_real``/``_forward_imaginary`` helpers
        it used) was silently dead code; worse, that dead variant called
        ``torch.complex`` on already-complex state tensors, which raises.
        This single ``forward`` keeps the surviving output behavior and,
        instead of returning ``None``, combines the final hidden states
        the same way as the outputs.
        """
        r2r_out, h_rr = self.gru_re(x.real)
        r2i_out, h_ri = self.gru_im(x.real)
        i2r_out, h_ir = self.gru_re(x.imag)
        i2i_out, h_ii = self.gru_im(x.imag)

        output = torch.complex(r2r_out - i2i_out, i2r_out + r2i_out)
        # Final hidden state combined with the same complex-product rule.
        state = torch.complex(h_rr - h_ii, h_ir + h_ri)
        return output, state
680+
681+
682+
class ComplexLSTM(Module):
    """Complex-valued LSTM built from two real-valued LSTMs.

    The complex sequence x = a + i*b is processed as
    out = (re(a) - im(b)) + i*(re(b) + im(a)), where re/im are the two
    real LSTMs. ``forward`` returns the complex output together with the
    pair of (real-path, imaginary-path) LSTM states.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        super().__init__()
        self.num_layer = num_layers
        self.hidden_size = hidden_size
        # Index of the batch dimension in the input tensor.
        self.batch_dim = 0 if batch_first else 1
        self.bidirectional = bidirectional

        self.lstm_re = LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, bias=bias,
                            batch_first=batch_first, dropout=dropout,
                            bidirectional=bidirectional)
        self.lstm_im = LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, bias=bias,
                            batch_first=batch_first, dropout=dropout,
                            bidirectional=bidirectional)

    def forward(self, x):
        real, state_real = self._forward_real(x)
        imaginary, state_imag = self._forward_imaginary(x)

        output = torch.complex(real, imaginary)
        return output, (state_real, state_imag)

    def _forward_real(self, x):
        # Real part of the complex product: re(a) - im(b).
        h_real, h_imag, c_real, c_imag = self._init_state(
            self._get_batch_size(x), x.is_cuda)
        real_real, (h_real, c_real) = self.lstm_re(x.real, (h_real, c_real))
        imag_imag, (h_imag, c_imag) = self.lstm_im(x.imag, (h_imag, c_imag))
        return real_real - imag_imag, ((h_real, c_real), (h_imag, c_imag))

    def _forward_imaginary(self, x):
        # Imaginary part of the complex product: re(b) + im(a).
        h_real, h_imag, c_real, c_imag = self._init_state(
            self._get_batch_size(x), x.is_cuda)
        imag_real, (h_real, c_real) = self.lstm_re(x.imag, (h_real, c_real))
        real_imag, (h_imag, c_imag) = self.lstm_im(x.real, (h_imag, c_imag))
        return imag_real + real_imag, ((h_real, c_real), (h_imag, c_imag))

    def _init_state(self, batch_size, to_gpu=False):
        """Zero initial (h, c) states for both LSTMs.

        BUG FIX: the original used ``dim_0 = 2 if bidirectional else 1``,
        ignoring ``num_layers``; torch.nn.LSTM expects states of shape
        (num_layers * num_directions, batch, hidden), so any model with
        num_layers > 1 crashed on forward.
        """
        num_directions = 2 if self.bidirectional else 1
        dims = (self.num_layer * num_directions, batch_size, self.hidden_size)

        states = [torch.zeros(dims) for _ in range(4)]
        if to_gpu:
            # NOTE(review): .cuda() ignores which device x actually lives on;
            # moving to x.device would be more robust — confirm with callers.
            states = [s.cuda() for s in states]
        h_real, h_imag, c_real, c_imag = states
        return h_real, h_imag, c_real, c_imag

    def _get_batch_size(self, x):
        return x.size(self.batch_dim)
618738
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
619739

620740
return h_new

0 commit comments

Comments
 (0)