1313from torch .nn import Module , Parameter , init
1414from torch .nn import Conv2d , Linear , BatchNorm1d , BatchNorm2d
1515from torch .nn import ConvTranspose2d
16+ < << << << HEAD :complexPyTorch / complexLayers .py
1617from .complexFunctions import complex_relu , complex_max_pool2d , complex_avg_pool2d
1718from .complexFunctions import complex_dropout , complex_dropout2d
19+ == == == =
20+ from complexFunctions import (
21+ complex_relu ,
22+ complex_tanh ,
23+ complex_sigmoid ,
24+ complex_max_pool2d ,
25+ complex_avg_pool2d ,
26+ complex_dropout ,
27+ complex_dropout2d ,
28+ complex_opposite ,
29+ )
30+ > >> >> >> d00e50c5d93386dd85b363a39c24cf783b7914aa :complexLayers .py
1831
1932def apply_complex (fr , fi , input , dtype = torch .complex64 ):
2033 return (fr (input .real )- fi (input .imag )).type (dtype ) \
@@ -83,6 +96,16 @@ class ComplexReLU(Module):
8396
8497 def forward (self ,input ):
8598 return complex_relu (input )
99+
100+ class ComplexSigmoid (Module ):
101+
102+ def forward (self ,input ):
103+ return complex_sigmoid (input )
104+
105+ class ComplexTanh (Module ):
106+
107+ def forward (self ,input ):
108+ return complex_tanh (input )
86109
87110class ComplexConvTranspose2d (Module ):
88111
@@ -342,3 +365,243 @@ def forward(self, input):
342365
343366 del Crr , Cri , Cii , Rrr , Rii , Rri , det , s , t
344367 return input
368+
369+ # class complexGruCell(Module):
370+
371+ # def __init__(self):
372+
373+ class ComplexGruCell (Module ):
374+ """
375+ A GRU cell for complex-valued inputs
376+ """
377+
378+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
379+ super (ComplexGruCell , self ).__init__ ()
380+ self .input_length = input_length
381+ self .hidden_length = hidden_length
382+
383+ # reset gate components
384+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
385+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
386+
387+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
388+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
389+
390+ # update gate components
391+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
392+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
393+
394+ self .activation_gate = ComplexSigmoid ()
395+ self .activation_candidate = ComplexTanh ()
396+
397+ def reset_gate (self , x , h ):
398+ x_1 = self .linear_reset_w1 (x )
399+ h_1 = self .linear_reset_r1 (h )
400+ # gate update
401+ reset = self .activation_gate (x_1 + h_1 )
402+ return reset
403+
404+ def update_gate (self , x , h ):
405+ x_2 = self .linear_reset_w2 (x )
406+ h_2 = self .linear_reset_r2 (h )
407+ z = self .activation_gate (h_2 + x_2 )
408+ return z
409+
410+ def update_component (self , x , h , r ):
411+ x_3 = self .linear_gate_w3 (x )
412+ h_3 = r * self .linear_gate_r3 (h ) # element-wise multiplication
413+ gate_update = self .activation_candidate (x_3 + h_3 )
414+ return gate_update
415+
416+ def forward (self , x , h ):
417+ # Equation 1. reset gate vector
418+ r = self .reset_gate (x , h )
419+
420+ # Equation 2: the update gate - the shared update gate vector z
421+ z = self .update_gate (x , h )
422+
423+ # Equation 3: The almost output component
424+ n = self .update_component (x , h , r )
425+
426+ # Equation 4: the new hidden state
427+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
428+
429+ return h_new
430+
431+ class ComplexBNGruCell (Module ):
432+ """
433+ A BN-GRU cell for complex-valued inputs
434+ """
435+
436+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
437+ super (ComplexBNGruCell , self ).__init__ ()
438+ self .input_length = input_length
439+ self .hidden_length = hidden_length
440+
441+ # reset gate components
442+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
443+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
444+
445+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
446+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
447+
448+ # update gate components
449+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
450+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
451+
452+ self .activation_gate = ComplexSigmoid ()
453+ self .activation_candidate = ComplexTanh ()
454+
455+ self .bn = ComplexBatchNorm2d (1 )
456+
457+ def reset_gate (self , x , h ):
458+ x_1 = self .linear_reset_w1 (x )
459+ h_1 = self .linear_reset_r1 (h )
460+ # gate update
461+ reset = self .activation_gate (self .bn (x_1 ) + self .bn (h_1 ))
462+ return reset
463+
464+ def update_gate (self , x , h ):
465+ x_2 = self .linear_reset_w2 (x )
466+ h_2 = self .linear_reset_r2 (h )
467+ z = self .activation_gate (self .bn (h_2 ) + self .bn (x_2 ))
468+ return z
469+
470+ def update_component (self , x , h , r ):
471+ x_3 = self .linear_gate_w3 (x )
472+ h_3 = r * self .bn (self .linear_gate_r3 (h )) # element-wise multiplication
473+ gate_update = self .activation_candidate (self .bn (self .bn (x_3 ) + h_3 ))
474+ return gate_update
475+
476+ def forward (self , x , h ):
477+ # Equation 1. reset gate vector
478+ r = self .reset_gate (x , h )
479+
480+ # Equation 2: the update gate - the shared update gate vector z
481+ z = self .update_gate (x , h )
482+
483+ # Equation 3: The almost output component
484+ n = self .update_component (x , h , r )
485+
486+ # Equation 4: the new hidden state
487+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
488+
489+ return h_new
490+
491+ class ComplexGRUCell (Module ):
492+ """
493+ A GRU cell for complex-valued inputs
494+ """
495+
496+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
497+ super (ComplexGRUCell , self ).__init__ ()
498+ self .input_length = input_length
499+ self .hidden_length = hidden_length
500+
501+ # reset gate components
502+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
503+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
504+
505+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
506+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
507+
508+ # update gate components
509+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
510+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
511+
512+ self .activation_gate = ComplexSigmoid ()
513+ self .activation_candidate = ComplexTanh ()
514+
515+ def reset_gate (self , x , h ):
516+ x_1 = self .linear_reset_w1 (x )
517+ h_1 = self .linear_reset_r1 (h )
518+ # gate update
519+ reset = self .activation_gate (x_1 + h_1 )
520+ return reset
521+
522+ def update_gate (self , x , h ):
523+ x_2 = self .linear_reset_w2 (x )
524+ h_2 = self .linear_reset_r2 (h )
525+ z = self .activation_gate (h_2 + x_2 )
526+ return z
527+
528+ def update_component (self , x , h , r ):
529+ x_3 = self .linear_gate_w3 (x )
530+ h_3 = r * self .linear_gate_r3 (h ) # element-wise multiplication
531+ gate_update = self .activation_candidate (x_3 + h_3 )
532+ return gate_update
533+
534+ def forward (self , x , h ):
535+ # Equation 1. reset gate vector
536+ r = self .reset_gate (x , h )
537+
538+ # Equation 2: the update gate - the shared update gate vector z
539+ z = self .update_gate (x , h )
540+
541+ # Equation 3: The almost output component
542+ n = self .update_component (x , h , r )
543+
544+ # Equation 4: the new hidden state
545+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
546+
547+ return h_new
548+
549+ class ComplexBNGRUCell (Module ):
550+ """
551+ A BN-GRU cell for complex-valued inputs
552+ """
553+
554+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
555+ super (ComplexBNGRUCell , self ).__init__ ()
556+ self .input_length = input_length
557+ self .hidden_length = hidden_length
558+
559+ # reset gate components
560+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
561+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
562+
563+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
564+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
565+
566+ # update gate components
567+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
568+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
569+
570+ self .activation_gate = ComplexSigmoid ()
571+ self .activation_candidate = ComplexTanh ()
572+
573+ self .bn = ComplexBatchNorm2d (1 )
574+
575+ def reset_gate (self , x , h ):
576+ x_1 = self .linear_reset_w1 (x )
577+ h_1 = self .linear_reset_r1 (h )
578+ # gate update
579+ reset = self .activation_gate (self .bn (x_1 ) + self .bn (h_1 ))
580+ return reset
581+
582+ def update_gate (self , x , h ):
583+ x_2 = self .linear_reset_w2 (x )
584+ h_2 = self .linear_reset_r2 (h )
585+ z = self .activation_gate (self .bn (h_2 ) + self .bn (x_2 ))
586+ return z
587+
588+ def update_component (self , x , h , r ):
589+ x_3 = self .linear_gate_w3 (x )
590+ h_3 = r * self .bn (self .linear_gate_r3 (h )) # element-wise multiplication
591+ gate_update = self .activation_candidate (self .bn (self .bn (x_3 ) + h_3 ))
592+ return gate_update
593+
594+ def forward (self , x , h ):
595+ # Equation 1. reset gate vector
596+ r = self .reset_gate (x , h )
597+
598+ # Equation 2: the update gate - the shared update gate vector z
599+ z = self .update_gate (x , h )
600+
601+ # Equation 3: The almost output component
602+ n = self .update_component (x , h , r )
603+
604+ # Equation 4: the new hidden state
605+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
606+
607+ return h_new
0 commit comments