@@ -584,6 +584,60 @@ def _avgpool2d_ops_fn(
584584 )
585585
586586
587+ def _batchnorm1d_ops_fn (
588+ module : nn .BatchNorm1d ,
589+ input : Tuple [torch .Tensor ],
590+ output : torch .Tensor
591+ ) -> int :
592+ if input [0 ].ndim == 2 :
593+ num_elements = input [0 ].size (0 )
594+
595+ elif input [0 ].ndim == 3 :
596+ num_elements = input [0 ].size (0 ) * input [0 ].size (- 1 )
597+
598+ else :
599+ raise AssertionError
600+
601+ if not module .affine :
602+ total_ops = 5 * num_elements + 4
603+
604+ else :
605+ if module .bias is not None :
606+ total_ops = 7 * num_elements + 4
607+
608+ else :
609+ total_ops = 6 * num_elements + 4
610+
611+ # Add num_features C
612+ total_ops *= module .num_features
613+
614+ return total_ops
615+
616+
617+ def _batchnorm2d_ops_fn (
618+ module : nn .BatchNorm2d ,
619+ input : Tuple [torch .Tensor ],
620+ output : torch .Tensor
621+ ) -> int :
622+ num_elements = input [0 ].size (0 ) * input [0 ].size (- 1 ) * input [0 ].size (- 2 )
623+
624+ if not module .affine :
625+ total_ops = 5 * num_elements + 4
626+
627+ else :
628+ if module .bias is not None :
629+ total_ops = 7 * num_elements + 4
630+
631+ else :
632+ total_ops = 6 * num_elements + 4
633+
634+ # Add num_features C
635+ total_ops *= module .num_features
636+
637+ return total_ops
638+
639+
640+
587641def get_default_ops_map () -> dict :
588642 return {
589643 # Default method
@@ -602,6 +656,8 @@ def get_default_ops_map() -> dict:
602656 nn .LSTM : _lstm_ops_fn ,
603657
604658 # Norm
659+ nn .BatchNorm1d : _batchnorm1d_ops_fn ,
660+ nn .BatchNorm2d : _batchnorm2d_ops_fn ,
605661 nn .LayerNorm : _layernorm_ops_fn ,
606662
607663 # Pooling
0 commit comments