@@ -370,30 +370,25 @@ def __init__(
370370 device : Optional [str ] = None ,
371371 ) -> None :
372372 """Initialize blank mean, variance, count."""
373- self .mean = th .zeros (shape , device = device )
373+ self .running_mean = th .zeros (shape , device = device )
374374 self .M2 = th .zeros (shape , device = device )
375375 self .count = 0
376376
377- def update (self , x : th .Tensor ) -> None :
377+ def update (self , batch : th .Tensor ) -> None :
378378 """Update the mean and variance with a batch `x`."""
379379 with th .no_grad ():
380- batch_mean = th .mean (x , dim = 0 )
381- batch_var = th .var (x , dim = 0 , unbiased = False )
382- batch_count = x .shape [0 ]
383- batch_M2 = batch_var * batch_count
384- if self .count == 0 :
385- self .count = batch_count
386- self .mean = batch_mean
387- self .M2 = batch_M2
388- return
389-
390- delta = batch_mean - self .mean
391- total_count = self .count + batch_count
392- self .mean += delta * batch_count / total_count
393-
394- self .M2 += batch_M2 + delta * delta * self .count * batch_count / total_count
395-
396- self .count = total_count
380+ batch_mean = th .mean (batch , dim = 0 )
381+ batch_var = th .var (batch , dim = 0 , unbiased = False )
382+ batch_count = batch .shape [0 ]
383+
384+ delta = batch_mean - self .running_mean
385+ tot_count = self .count + batch_count
386+ self .running_mean += delta * batch_count / tot_count
387+
388+ self .M2 += batch_var * batch_count
389+ self .M2 += th .square (delta ) * self .count * batch_count / tot_count
390+
391+ self .count += batch_count
397392
398393 @property
399394 def var (self ) -> th .Tensor :
0 commit comments