1111from typing import Optional
1212
1313import torch
14- from complexFunctions import (complex_avg_pool2d , complex_dropout ,
15- complex_dropout2d , complex_max_pool2d ,
16- complex_opposite , complex_relu , complex_sigmoid ,
17- complex_tanh )
18- from torch .nn import (BatchNorm1d , BatchNorm2d , Conv2d , ConvTranspose2d ,
19- Linear , Module , Parameter , init )
14+ from torch .nn import (
15+ Module , Parameter , init ,
16+ Conv2d , ConvTranspose2d , Linear , LSTM , GRU ,
17+ BatchNorm1d , BatchNorm2d ,
18+ PReLU
19+ )
20+
21+ from .complexFunctions import (
22+ complex_relu ,
23+ complex_tanh ,
24+ complex_sigmoid ,
25+ complex_max_pool2d ,
26+ complex_avg_pool2d ,
27+ complex_dropout ,
28+ complex_dropout2d ,
29+ complex_opposite ,
30+ )
31+
32+
def apply_complex(fr, fi, input, dtype=torch.complex64):
    """Apply a complex operator built from two real-valued modules.

    ``fr`` and ``fi`` are callables (typically real nn layers) holding the
    real and imaginary parts of a complex weight W = fr + i*fi.  For a
    complex input x = a + i*b, the complex product rule gives
    W(x) = (fr(a) - fi(b)) + i*(fr(b) + fi(a)).

    NOTE(review): the file previously contained two duplicate copies of this
    function (unresolved merge residue); they are consolidated here.

    :param fr: callable producing the real-weight response of a real tensor
    :param fi: callable producing the imaginary-weight response
    :param input: complex tensor (exposes ``.real`` and ``.imag``)
    :param dtype: complex dtype of the result (default ``torch.complex64``)
    :return: complex tensor of dtype ``dtype``
    """
    real = (fr(input.real) - fi(input.imag)).type(dtype)
    imag = (fr(input.imag) + fi(input.real)).type(dtype)
    return real + 1j * imag
class ComplexDropout(Module):
    """Dropout for complex-valued tensors.

    Delegates to ``complex_dropout`` during training; in eval mode the
    input passes through unchanged (standard dropout semantics).
    """

    def __init__(self, p=0.5):
        super().__init__()
        # Drop probability forwarded to complex_dropout.
        self.p = p

    def forward(self, input):
        if self.training:
            return complex_dropout(input, self.p)
        # Fix: this branch previously read `return inp`, a stale name left
        # over from renaming the parameter to `input`, which raised
        # NameError whenever the module was used in eval mode.
        return input
50+
3951class ComplexDropout2d (Module ):
4052 def __init__ (self , p = 0.5 ):
4153 super (ComplexDropout2d , self ).__init__ ()
@@ -89,8 +101,8 @@ def __init__(self,kernel_size, stride= None, padding = 0,
89101 self .count_include_pad = count_include_pad
90102 self .divisor_override = divisor_override
91103
92- def forward (self ,input ):
93- return complex_avg_pool2d (input ,kernel_size = self .kernel_size ,
104+ def forward (self ,inp ):
105+ return complex_avg_pool2d (inp ,kernel_size = self .kernel_size ,
94106 stride = self .stride , padding = self .padding ,
95107 ceil_mode = self .ceil_mode , count_include_pad = self .count_include_pad ,
96108 divisor_override = self .divisor_override )
@@ -106,6 +118,16 @@ class ComplexSigmoid(Module):
106118 @staticmethod
107119 def forward (inp ):
108120 return complex_sigmoid (inp )
121+
class ComplexPReLU(Module):
    """PReLU applied independently to the real and imaginary parts.

    Each part gets its own learnable negative-slope parameter.
    """

    def __init__(self):
        super().__init__()
        # Separate learnable slopes for the real and imaginary channels.
        self.r_prelu = PReLU()
        self.i_prelu = PReLU()

    def forward(self, inp):
        # Fix: forward was decorated @staticmethod while still taking
        # `self`, so Module.__call__ passed the input tensor as `self`
        # and every call raised TypeError.  It is an ordinary instance
        # method (it reads self.r_prelu / self.i_prelu).
        return self.r_prelu(inp.real) + 1j * self.i_prelu(inp.imag)
109131
110132
111133class ComplexTanh (Module ):
@@ -129,32 +151,12 @@ def __init__(
129151 padding_mode = "zeros" ,
130152 ):
131153
132- super (ComplexConvTranspose2d , self ).__init__ ()
154+ super ().__init__ ()
133155
134- self .conv_tran_r = ConvTranspose2d (
135- in_channels ,
136- out_channels ,
137- kernel_size ,
138- stride ,
139- padding ,
140- output_padding ,
141- groups ,
142- bias ,
143- dilation ,
144- padding_mode ,
145- )
146- self .conv_tran_i = ConvTranspose2d (
147- in_channels ,
148- out_channels ,
149- kernel_size ,
150- stride ,
151- padding ,
152- output_padding ,
153- groups ,
154- bias ,
155- dilation ,
156- padding_mode ,
157- )
156+ self .conv_tran_r = ConvTranspose2d (in_channels , out_channels , kernel_size , stride , padding ,
157+ output_padding , groups , bias , dilation , padding_mode )
158+ self .conv_tran_i = ConvTranspose2d (in_channels , out_channels , kernel_size , stride , padding ,
159+ output_padding , groups , bias , dilation , padding_mode )
158160
159161 def forward (self , inp ):
160162 return apply_complex (self .conv_tran_r , self .conv_tran_i , inp )
@@ -200,7 +202,7 @@ def forward(self, inp):
200202
201203class ComplexLinear (Module ):
202204 def __init__ (self , in_features , out_features ):
203- super (ComplexLinear , self ).__init__ ()
205+ super ().__init__ ()
204206 self .fc_r = Linear (in_features , out_features )
205207 self .fc_i = Linear (in_features , out_features )
206208
@@ -315,7 +317,7 @@ def reset_parameters(self):
315317 init .constant_ (self .weight [:, :2 ], 1.4142135623730951 )
316318 init .zeros_ (self .weight [:, 2 ])
317319 init .zeros_ (self .bias )
318-
320+
319321
320322class ComplexBatchNorm2d (_ComplexBatchNorm ):
321323 def forward (self , inp ):
@@ -325,7 +327,8 @@ def forward(self, inp):
325327 if self .num_batches_tracked is not None :
326328 self .num_batches_tracked += 1
327329 if self .momentum is None : # use cumulative moving average
328- exponential_average_factor = 1.0 / float (self .num_batches_tracked )
330+ exponential_average_factor = 1.0 / \
331+ float (self .num_batches_tracked )
329332 else : # use exponential moving average
330333 exponential_average_factor = self .momentum
331334
@@ -405,7 +408,6 @@ def forward(self, inp):
405408 ).type (
406409 torch .complex64
407410 )
408-
409411 return inp
410412
411413
@@ -418,7 +420,8 @@ def forward(self, inp):
418420 if self .num_batches_tracked is not None :
419421 self .num_batches_tracked += 1
420422 if self .momentum is None : # use cumulative moving average
421- exponential_average_factor = 1.0 / float (self .num_batches_tracked )
423+ exponential_average_factor = 1.0 / \
424+ float (self .num_batches_tracked )
422425 else : # use exponential moving average
423426 exponential_average_factor = self .momentum
424427
@@ -500,26 +503,32 @@ def forward(self, inp):
500503 return inp
501504
502505
506+
503507class ComplexGRUCell (Module ):
504508 """
505509 A GRU cell for complex-valued inputs
506510 """
507-
508- def __init__ (self , inp_length = 10 , hidden_length = 20 ):
509- super (ComplexGRUCell , self ).__init__ ()
510- self .inp_length = inp_length
511+ def __init__ (self , input_length , hidden_length ):
512+ super ().__init__ ()
513+ self .input_length = input_length
511514 self .hidden_length = hidden_length
512515
513516 # reset gate components
514- self .linear_reset_w1 = ComplexLinear (self .inp_length , self .hidden_length )
515- self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
517+ self .linear_reset_w1 = ComplexLinear (
518+ self .input_length , self .hidden_length )
519+ self .linear_reset_r1 = ComplexLinear (
520+ self .hidden_length , self .hidden_length )
516521
517- self .linear_reset_w2 = ComplexLinear (self .inp_length , self .hidden_length )
518- self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
522+ self .linear_reset_w2 = ComplexLinear (
523+ self .input_length , self .hidden_length )
524+ self .linear_reset_r2 = ComplexLinear (
525+ self .hidden_length , self .hidden_length )
519526
520527 # update gate components
521- self .linear_gate_w3 = ComplexLinear (self .inp_length , self .hidden_length )
522- self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
528+ self .linear_gate_w3 = ComplexLinear (
529+ self .input_length , self .hidden_length )
530+ self .linear_gate_r3 = ComplexLinear (
531+ self .hidden_length , self .hidden_length )
523532
524533 self .activation_gate = ComplexSigmoid ()
525534 self .activation_candidate = ComplexTanh ()
@@ -555,30 +564,35 @@ def forward(self, x, h):
555564
556565 # Equation 4: the new hidden state
557566 h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
558-
559567 return h_new
560568
561569
562570class ComplexBNGRUCell (Module ):
563571 """
564572 A BN-GRU cell for complex-valued inputs
565573 """
566-
567- def __init__ (self , inp_length = 10 , hidden_length = 20 ):
568- super (ComplexBNGRUCell , self ).__init__ ()
569- self .inp_length = inp_length
574+
575+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
576+ super ().__init__ ()
577+ self .input_length = input_length
570578 self .hidden_length = hidden_length
571579
572580 # reset gate components
573- self .linear_reset_w1 = ComplexLinear (self .inp_length , self .hidden_length )
574- self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
581+ self .linear_reset_w1 = ComplexLinear (
582+ self .input_length , self .hidden_length )
583+ self .linear_reset_r1 = ComplexLinear (
584+ self .hidden_length , self .hidden_length )
575585
576- self .linear_reset_w2 = ComplexLinear (self .inp_length , self .hidden_length )
577- self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
586+ self .linear_reset_w2 = ComplexLinear (
587+ self .input_length , self .hidden_length )
588+ self .linear_reset_r2 = ComplexLinear (
589+ self .hidden_length , self .hidden_length )
578590
579591 # update gate components
580- self .linear_gate_w3 = ComplexLinear (self .inp_length , self .hidden_length )
581- self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
592+ self .linear_gate_w3 = ComplexLinear (
593+ self .input_length , self .hidden_length )
594+ self .linear_gate_r3 = ComplexLinear (
595+ self .hidden_length , self .hidden_length )
582596
583597 self .activation_gate = ComplexSigmoid ()
584598 self .activation_candidate = ComplexTanh ()
@@ -615,6 +629,112 @@ def forward(self, x, h):
615629 n = self .update_component (x , h , r )
616630
617631 # Equation 4: the new hidden state
632+
633+
class ComplexGRU(Module):
    """GRU over complex-valued sequences.

    Two real GRUs hold the real (``gru_re``) and imaginary (``gru_im``)
    parts of the recurrent weights; real/imaginary responses are combined
    with the complex product rule (a+ib)(c+id) = (ac-bd) + i(ad+bc).
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        super().__init__()

        self.gru_re = GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, bias=bias,
                          batch_first=batch_first, dropout=dropout,
                          bidirectional=bidirectional)
        self.gru_im = GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, bias=bias,
                          batch_first=batch_first, dropout=dropout,
                          bidirectional=bidirectional)

    def forward(self, x):
        """Return (complex output sequence, complex final hidden state).

        Fix: this class previously defined ``forward`` twice.  The second
        definition silently shadowed the first and returned ``None`` for
        the state; the shadowed first one would have crashed anyway,
        since it fed already-complex tensors back into ``torch.complex``
        (which requires real-valued inputs).  This single forward computes
        both output and hidden state with the product rule.
        """
        r2r_out, r2r_h = self.gru_re(x.real)
        i2i_out, i2i_h = self.gru_im(x.imag)
        i2r_out, i2r_h = self.gru_re(x.imag)
        r2i_out, r2i_h = self.gru_im(x.real)

        output = torch.complex(r2r_out - i2i_out, i2r_out + r2i_out)
        state = torch.complex(r2r_h - i2i_h, i2r_h + r2i_h)
        return output, state

    def _forward_real(self, x):
        """Real part of the output, plus a complex (h_re, h_im) state pair."""
        real_real, h_real = self.gru_re(x.real)
        imag_imag, h_imag = self.gru_im(x.imag)
        return real_real - imag_imag, torch.complex(h_real, h_imag)

    def _forward_imaginary(self, x):
        """Imaginary part of the output, plus a complex (h_re, h_im) state pair."""
        imag_real, h_real = self.gru_re(x.imag)
        real_imag, h_imag = self.gru_im(x.real)
        return imag_real + real_imag, torch.complex(h_real, h_imag)
680+
681+
class ComplexLSTM(Module):
    """LSTM over complex-valued sequences built from two real LSTMs.

    ``lstm_re`` / ``lstm_im`` hold the real and imaginary weight parts;
    outputs are combined with the complex product rule.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        super().__init__()
        self.num_layer = num_layers
        self.hidden_size = hidden_size
        # Batch dimension of the input: 0 when batch_first, else 1
        # (seq, batch, feature layout).
        self.batch_dim = 0 if batch_first else 1
        self.bidirectional = bidirectional

        self.lstm_re = LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, bias=bias,
                            batch_first=batch_first, dropout=dropout,
                            bidirectional=bidirectional)
        self.lstm_im = LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, bias=bias,
                            batch_first=batch_first, dropout=dropout,
                            bidirectional=bidirectional)

    def forward(self, x):
        """Return (complex output, (real-pass states, imaginary-pass states))."""
        real, state_real = self._forward_real(x)
        imaginary, state_imag = self._forward_imaginary(x)
        output = torch.complex(real, imaginary)
        return output, (state_real, state_imag)

    def _forward_real(self, x):
        """Real part: lstm_re(x.real) - lstm_im(x.imag)."""
        h_real, h_imag, c_real, c_imag = self._init_state(
            self._get_batch_size(x), x.is_cuda)
        real_real, (h_real, c_real) = self.lstm_re(x.real, (h_real, c_real))
        imag_imag, (h_imag, c_imag) = self.lstm_im(x.imag, (h_imag, c_imag))
        return real_real - imag_imag, ((h_real, c_real), (h_imag, c_imag))

    def _forward_imaginary(self, x):
        """Imaginary part: lstm_re(x.imag) + lstm_im(x.real)."""
        h_real, h_imag, c_real, c_imag = self._init_state(
            self._get_batch_size(x), x.is_cuda)
        imag_real, (h_real, c_real) = self.lstm_re(x.imag, (h_real, c_real))
        real_imag, (h_imag, c_imag) = self.lstm_im(x.real, (h_imag, c_imag))
        return imag_real + real_imag, ((h_real, c_real), (h_imag, c_imag))

    def _init_state(self, batch_size, to_gpu=False):
        """Zero-initialized (h_re, h_im, c_re, c_im) state tensors.

        Fix: the leading dimension previously ignored ``num_layers`` and
        used only the direction count, so any ``num_layers > 1`` model
        crashed with a state-shape mismatch.  torch.nn.LSTM expects states
        of shape (num_layers * num_directions, batch, hidden_size).
        """
        num_directions = 2 if self.bidirectional else 1
        dims = (self.num_layer * num_directions, batch_size, self.hidden_size)
        states = [torch.zeros(dims) for _ in range(4)]
        if to_gpu:
            # NOTE(review): only handles CUDA; other accelerators/dtypes
            # would need x.device-aware allocation — TODO confirm scope.
            states = [s.cuda() for s in states]
        return tuple(states)

    def _get_batch_size(self, x):
        return x.size(self.batch_dim)
618738 h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
619739
620740 return h_new
0 commit comments