@@ -83,6 +83,16 @@ class ComplexReLU(Module):
8383
8484 def forward (self ,input ):
8585 return complex_relu (input )
86+
87+ class ComplexSigmoid (Module ):
88+
89+ def forward (self ,input ):
90+ return complex_sigmoid (input )
91+
92+ class ComplexTanh (Module ):
93+
94+ def forward (self ,input ):
95+ return complex_tanh (input )
8696
8797class ComplexConvTranspose2d (Module ):
8898
@@ -342,3 +352,125 @@ def forward(self, input):
342352
343353 del Crr , Cri , Cii , Rrr , Rii , Rri , det , s , t
344354 return input
355+
356+ # class complexGruCell(Module):
357+
358+ # def __init__(self):
359+
360+ class ComplexGruCell (Module ):
361+ """
362+ A GRU cell for complex-valued inputs
363+ """
364+
365+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
366+ super (ComplexGruCell , self ).__init__ ()
367+ self .input_length = input_length
368+ self .hidden_length = hidden_length
369+
370+ # reset gate components
371+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
372+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
373+
374+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
375+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
376+
377+ # update gate components
378+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
379+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
380+
381+ self .activation_gate = ComplexSigmoid ()
382+ self .activation_candidate = ComplexTanh ()
383+
384+ def reset_gate (self , x , h ):
385+ x_1 = self .linear_reset_w1 (x )
386+ h_1 = self .linear_reset_r1 (h )
387+ # gate update
388+ reset = self .activation_gate (x_1 + h_1 )
389+ return reset
390+
391+ def update_gate (self , x , h ):
392+ x_2 = self .linear_reset_w2 (x )
393+ h_2 = self .linear_reset_r2 (h )
394+ z = self .activation_gate (h_2 + x_2 )
395+ return z
396+
397+ def update_component (self , x , h , r ):
398+ x_3 = self .linear_gate_w3 (x )
399+ h_3 = r * self .linear_gate_r3 (h ) # element-wise multiplication
400+ gate_update = self .activation_candidate (x_3 + h_3 )
401+ return gate_update
402+
403+ def forward (self , x , h ):
404+ # Equation 1. reset gate vector
405+ r = self .reset_gate (x , h )
406+
407+ # Equation 2: the update gate - the shared update gate vector z
408+ z = self .update_gate (x , h )
409+
410+ # Equation 3: The almost output component
411+ n = self .update_component (x , h , r )
412+
413+ # Equation 4: the new hidden state
414+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
415+
416+ return h_new
417+
418+ class ComplexBNGruCell (Module ):
419+ """
420+ A BN-GRU cell for complex-valued inputs
421+ """
422+
423+ def __init__ (self , input_length = 10 , hidden_length = 20 ):
424+ super (ComplexBNGruCell , self ).__init__ ()
425+ self .input_length = input_length
426+ self .hidden_length = hidden_length
427+
428+ # reset gate components
429+ self .linear_reset_w1 = ComplexLinear (self .input_length , self .hidden_length )
430+ self .linear_reset_r1 = ComplexLinear (self .hidden_length , self .hidden_length )
431+
432+ self .linear_reset_w2 = ComplexLinear (self .input_length , self .hidden_length )
433+ self .linear_reset_r2 = ComplexLinear (self .hidden_length , self .hidden_length )
434+
435+ # update gate components
436+ self .linear_gate_w3 = ComplexLinear (self .input_length , self .hidden_length )
437+ self .linear_gate_r3 = ComplexLinear (self .hidden_length , self .hidden_length )
438+
439+ self .activation_gate = ComplexSigmoid ()
440+ self .activation_candidate = ComplexTanh ()
441+
442+ self .bn = ComplexBatchNorm2d (1 )
443+
444+ def reset_gate (self , x , h ):
445+ x_1 = self .linear_reset_w1 (x )
446+ h_1 = self .linear_reset_r1 (h )
447+ # gate update
448+ reset = self .activation_gate (self .bn (x_1 ) + self .bn (h_1 ))
449+ return reset
450+
451+ def update_gate (self , x , h ):
452+ x_2 = self .linear_reset_w2 (x )
453+ h_2 = self .linear_reset_r2 (h )
454+ z = self .activation_gate (self .bn (h_2 ) + self .bn (x_2 ))
455+ return z
456+
457+ def update_component (self , x , h , r ):
458+ x_3 = self .linear_gate_w3 (x )
459+ h_3 = r * self .bn (self .linear_gate_r3 (h )) # element-wise multiplication
460+ gate_update = self .activation_candidate (self .bn (self .bn (x_3 ) + h_3 ))
461+ return gate_update
462+
463+ def forward (self , x , h ):
464+ # Equation 1. reset gate vector
465+ r = self .reset_gate (x , h )
466+
467+ # Equation 2: the update gate - the shared update gate vector z
468+ z = self .update_gate (x , h )
469+
470+ # Equation 3: The almost output component
471+ n = self .update_component (x , h , r )
472+
473+ # Equation 4: the new hidden state
474+ h_new = (1 + complex_opposite (z )) * n + z * h # element-wise multiplication
475+
476+ return h_new
0 commit comments