Skip to content

Commit 4b0f3f6

Browse files
author
Esteban Gómez
committed
Add BatchNorm1d and BatchNorm2d ops support
1 parent 2f66e14 commit 4b0f3f6

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

src/moduleprofiler/ops.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,60 @@ def _avgpool2d_ops_fn(
584584
)
585585

586586

587+
def _batchnorm1d_ops_fn(
588+
module: nn.BatchNorm1d,
589+
input: Tuple[torch.Tensor],
590+
output: torch.Tensor
591+
) -> int:
592+
if input[0].ndim == 2:
593+
num_elements = input[0].size(0)
594+
595+
elif input[0].ndim == 3:
596+
num_elements = input[0].size(0) * input[0].size(-1)
597+
598+
else:
599+
raise AssertionError
600+
601+
if not module.affine:
602+
total_ops = 5 * num_elements + 4
603+
604+
else:
605+
if module.bias is not None:
606+
total_ops = 7 * num_elements + 4
607+
608+
else:
609+
total_ops = 6 * num_elements + 4
610+
611+
# Add num_features C
612+
total_ops *= module.num_features
613+
614+
return total_ops
615+
616+
617+
def _batchnorm2d_ops_fn(
618+
module: nn.BatchNorm2d,
619+
input: Tuple[torch.Tensor],
620+
output: torch.Tensor
621+
) -> int:
622+
num_elements = input[0].size(0) * input[0].size(-1) * input[0].size(-2)
623+
624+
if not module.affine:
625+
total_ops = 5 * num_elements + 4
626+
627+
else:
628+
if module.bias is not None:
629+
total_ops = 7 * num_elements + 4
630+
631+
else:
632+
total_ops = 6 * num_elements + 4
633+
634+
# Add num_features C
635+
total_ops *= module.num_features
636+
637+
return total_ops
638+
639+
640+
587641
def get_default_ops_map() -> dict:
588642
return {
589643
# Default method
@@ -602,6 +656,8 @@ def get_default_ops_map() -> dict:
602656
nn.LSTM: _lstm_ops_fn,
603657

604658
# Norm
659+
nn.BatchNorm1d: _batchnorm1d_ops_fn,
660+
nn.BatchNorm2d: _batchnorm2d_ops_fn,
605661
nn.LayerNorm: _layernorm_ops_fn,
606662

607663
# Pooling

0 commit comments

Comments
 (0)