Skip to content

Commit 3f22218

Browse files
author
Esteban Gómez
committed
Fix _layernorm_ops_fn
1 parent 6c412d0 commit 3f22218

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

src/moduleprofiler/ops.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -531,11 +531,18 @@ def _layernorm_ops_fn(
531531
input: Tuple[torch.Tensor],
532532
output: torch.Tensor
533533
) -> int:
534-
if not module.elementwise_affine:
535-
num_elements = (
536-
module.normalized_shape if isinstance(module.normalized_shape, int)
537-
else math.prod(module.normalized_shape)
534+
num_elements = math.prod(module.normalized_shape)
535+
536+
if len(module.normalized_shape) == input[0].ndim:
537+
batch_size = 1
538+
539+
else:
540+
batch_size_end_dim = input[0].ndim - len(module.normalized_shape) - 1
541+
batch_size = math.prod(
542+
[input[0].size(n) for n in range(batch_size_end_dim + 1)]
538543
)
544+
545+
if not module.elementwise_affine:
539546
total_ops = 5 * num_elements + 4
540547

541548
else:
@@ -546,7 +553,7 @@ def _layernorm_ops_fn(
546553
total_ops = 6 * num_elements + 4
547554

548555
# Add batch size
549-
total_ops *= input[0].size(0)
556+
total_ops *= batch_size
550557

551558
return total_ops
552559

0 commit comments

Comments
 (0)