@@ -206,7 +206,6 @@ class ComplexBatchNorm2d(_ComplexBatchNorm):
206206 def forward (self , input ):
207207 exponential_average_factor = 0.0
208208
209-
210209 if self .training and self .track_running_stats :
211210 if self .num_batches_tracked is not None :
212211 self .num_batches_tracked += 1
@@ -254,10 +253,6 @@ def forward(self, input):
254253 self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
255254 + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
256255
257-
258-
259-
260-
261256 # calculate the inverse square root the covariance matrix
262257 det = Crr * Cii - Cri .pow (2 )
263258 s = torch .sqrt (det )
@@ -353,17 +348,13 @@ def forward(self, input):
353348 del Crr , Cri , Cii , Rrr , Rii , Rri , det , s , t
354349 return input
355350
356- # class complexGruCell(Module):
357-
358- # def __init__(self):
359-
360- class ComplexGruCell (Module ):
351+ class ComplexGRUCell (Module ):
361352 """
362353 A GRU cell for complex-valued inputs
363354 """
364355
365356 def __init__ (self , input_length = 10 , hidden_length = 20 ):
366- super (ComplexGruCell , self ).__init__ ()
357+ super (ComplexGRUCell , self ).__init__ ()
367358 self .input_length = input_length
368359 self .hidden_length = hidden_length
369360
@@ -415,13 +406,13 @@ def forward(self, x, h):
415406
416407 return h_new
417408
418- class ComplexBNGruCell (Module ):
409+ class ComplexBNGRUCell (Module ):
419410 """
420411 A BN-GRU cell for complex-valued inputs
421412 """
422413
423414 def __init__ (self , input_length = 10 , hidden_length = 20 ):
424- super (ComplexBNGruCell , self ).__init__ ()
415+ super (ComplexBNGRUCell , self ).__init__ ()
425416 self .input_length = input_length
426417 self .hidden_length = hidden_length
427418
0 commit comments