File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -531,11 +531,18 @@ def _layernorm_ops_fn(
531531 input : Tuple [torch .Tensor ],
532532 output : torch .Tensor
533533) -> int :
534- if not module .elementwise_affine :
535- num_elements = (
536- module .normalized_shape if isinstance (module .normalized_shape , int )
537- else math .prod (module .normalized_shape )
534+ num_elements = math .prod (module .normalized_shape )
535+
536+ if len (module .normalized_shape ) == input [0 ].ndim :
537+ batch_size = 1
538+
539+ else :
540+ batch_size_end_dim = input [0 ].ndim - len (module .normalized_shape ) - 1
541+ batch_size = math .prod (
542+ [input [0 ].size (n ) for n in range (batch_size_end_dim + 1 )]
538543 )
544+
545+ if not module .elementwise_affine :
539546 total_ops = 5 * num_elements + 4
540547
541548 else :
@@ -546,7 +553,7 @@ def _layernorm_ops_fn(
546553 total_ops = 6 * num_elements + 4
547554
548555 # Add batch size
549- total_ops *= input [ 0 ]. size ( 0 )
556+ total_ops *= batch_size
550557
551558 return total_ops
552559
You can’t perform that action at this time.
0 commit comments