Skip to content

Commit f3cfd52

Browse files
authored
Merge pull request #4 from eagomez2/develop
Develop
2 parents f1d606b + 831c5cb commit f3cfd52

5 files changed

Lines changed: 85 additions & 8 deletions

File tree

docs/index.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,6 @@ By default, all methods support all modules as long as these are instances of `t
125125
| `torch.nn.AdaptiveMaxPool2d` | :material-check: | 0.0.1 |
126126
| `torch.nn.MaxPool1d` | :material-check: | 0.0.1 |
127127
| `torch.nn.MaxPool2d` | :material-check: | 0.0.1 |
128-
| `torch.nn.LayerNorm` | :material-check: | 0.0.1 |
128+
| `torch.nn.LayerNorm` | :material-check: | 0.0.1 |
129+
| `torch.nn.BatchNorm1d` | :material-check: | 0.0.4 |
130+
| `torch.nn.BatchNorm2d` | :material-check: | 0.0.4 |

docs/modules/layernorm.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Where
2222
## Complexity
2323
The complexity of a `torch.nn.LayerNorm` layer can be divided into two parts: The aggregated statistics calculation (i.e. mean and standard deviation) and the affine transformation applied by $\gamma$ and $\beta$ if `elementwise_affine=True`.
2424

25-
## Aggregated statistics
25+
### Aggregated statistics
2626
The complexity of the mean corresponds to the sum of all elements in the last $D$ dimensions of the input tensor $x$ and the division of that number by the total number of elements. As an example, if `normalized_shape=(3, 5)` then there are 14 additions and 1 division. This also corresponds to the product of the dimensions involved in `normalized_shape`.
2727

2828
$$
@@ -63,7 +63,7 @@ $$
6363
\end{equation}
6464
$$
6565

66-
## Elementwise affine
66+
### Elementwise affine
6767
If `elementwise_affine=True`, there is an element-wise multiplication by $\gamma$. If `bias=True`, there is also an element-wise addition by $\beta$. Therefore the whole complexity of affine transformations is
6868

6969
$$
@@ -82,15 +82,15 @@ $$
8282

8383
when `bias=True`.
8484

85-
## Batch size
85+
### Batch size
8686
So far we have not included the batch size $N$, which in this case could be defined as all other dimensions that are not $D$. This means, those that are not included in `normalized_shape`.
8787

8888
!!! note
8989
Please note that $N$ here corresponds to all dimensions not included in `normalized_shape`, which is different from the definition of $N$ in `torch.var` which corresponds to the number of elements in the input tensor of that function.
9090

9191
The batch size $N$ multiplies all previously calculated operations by a factor $\eta$ corresponding to the multiplication of the remaining dimensions. For example, if the input tensor has size `(2, 3, 5)` and `normalized_shape=(3, 5)`, then $\eta$ is $2$.
9292

93-
## Total complexity
93+
### Total complexity
9494
Including all previously calculated factors, the total complexity can be summarized as
9595

9696
$$

src/moduleprofiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"Module profiler"
2-
__version__ = "0.0.3"
2+
__version__ = "0.0.4"
33

44
__all__ = [
55
"get_default_ops_map",

src/moduleprofiler/ops.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ def _default_ops_fn(
1515
return None
1616

1717

18+
def _excluded_ops_fn(
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor
) -> Any:
    """Sentinel ops function mapped to the ``"excluded"`` key of the ops map.

    Modules listed in ``exclude_from_ops`` are routed here by the profiler so
    they contribute no operation count (``None``), regardless of their type.

    Args:
        module (nn.Module): Module excluded from ops estimation (unused).
        input (Tuple[torch.Tensor]): Forward-hook input tuple (unused).
        output (torch.Tensor): Forward-hook output tensor (unused).

    Returns:
        Any: Always ``None``.
    """
    return None
24+
25+
1826
def _identity_ops_fn(
1927
module: nn.Identity,
2028
input: Tuple[torch.Tensor],
@@ -584,11 +592,60 @@ def _avgpool2d_ops_fn(
584592
)
585593

586594

595+
def _batchnorm1d_ops_fn(
596+
module: nn.BatchNorm1d,
597+
input: Tuple[torch.Tensor],
598+
output: torch.Tensor
599+
) -> int:
600+
if input[0].ndim == 2:
601+
num_elements = input[0].size(0)
602+
603+
elif input[0].ndim == 3:
604+
num_elements = input[0].size(0) * input[0].size(-1)
605+
606+
else:
607+
raise AssertionError
608+
609+
if not module.affine:
610+
total_ops = 5 * num_elements + 4
611+
612+
else:
613+
total_ops = 7 * num_elements + 4
614+
615+
# Add num_features C
616+
total_ops *= module.num_features
617+
618+
return total_ops
619+
620+
621+
def _batchnorm2d_ops_fn(
622+
module: nn.BatchNorm2d,
623+
input: Tuple[torch.Tensor],
624+
output: torch.Tensor
625+
) -> int:
626+
num_elements = input[0].size(0) * input[0].size(-1) * input[0].size(-2)
627+
628+
if not module.affine:
629+
total_ops = 5 * num_elements + 4
630+
631+
else:
632+
total_ops = 7 * num_elements + 4
633+
634+
# Add num_features C
635+
total_ops *= module.num_features
636+
637+
return total_ops
638+
639+
640+
587641
def get_default_ops_map() -> dict:
588642
return {
589643
# Default method
590644
"default": _default_ops_fn,
591645

646+
# Excluded module method
647+
"excluded": _excluded_ops_fn,
648+
592649
# Layers
593650
nn.Identity: _identity_ops_fn,
594651
nn.Linear: _linear_ops_fn,
@@ -602,6 +659,8 @@ def get_default_ops_map() -> dict:
602659
nn.LSTM: _lstm_ops_fn,
603660

604661
# Norm
662+
nn.BatchNorm1d: _batchnorm1d_ops_fn,
663+
nn.BatchNorm2d: _batchnorm2d_ops_fn,
605664
nn.LayerNorm: _layernorm_ops_fn,
606665

607666
# Pooling

src/moduleprofiler/profiler.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Any,
77
Callable,
88
Dict,
9+
List,
910
Optional,
1011
Tuple,
1112
Union
@@ -43,6 +44,8 @@ class ModuleProfiler:
4344
their corresponding functions used to trace its size.
4445
ops_fn_map (dict): Dictionary containing a map between modules and
4546
their corresponding function to estimate the number of operations.
47+
exclude_from_ops (Optional[List[nn.Module]]): Modules to exclude from
48+
ops estimations.
4649
ts_fmt (str): Timestamp format used to print messages if
4750
`verbose=True`.
4851
verbose (bool): If ``True``, enabled verbose output mode.
@@ -56,6 +59,7 @@ def __init__(
5659
inference_end_attr: str = "__inference_end__",
5760
io_size_fn_map: dict = get_default_io_size_map(),
5861
ops_fn_map: dict = get_default_ops_map(),
62+
exclude_from_ops: Optional[List[nn.Module]] = None,
5963
ts_fmt: str = "%Y-%m-%d %H:%M:%S",
6064
verbose: bool = False
6165
) -> None:
@@ -69,6 +73,7 @@ def __init__(
6973
self.inference_end_attr = inference_end_attr
7074
self.io_size_fn_map = io_size_fn_map
7175
self.ops_fn_map = ops_fn_map
76+
self.exclude_from_ops = exclude_from_ops
7277
self.verbose = verbose
7378
self._logger = Logger(ts_fmt=ts_fmt)
7479
self._hook_handles = []
@@ -312,7 +317,14 @@ def _ops_fn(
312317
"""
313318
# Obtain method to estimate ops
314319
if module.__class__ in self.ops_fn_map:
315-
ops_fn = self.ops_fn_map[type(module)]
320+
if (
321+
self.exclude_from_ops is not None
322+
and module.__class__ in self.exclude_from_ops
323+
):
324+
ops_fn = self.ops_fn_map["excluded"]
325+
326+
else:
327+
ops_fn = self.ops_fn_map[type(module)]
316328

317329
else:
318330
ops_fn = self.ops_fn_map["default"]
@@ -368,7 +380,11 @@ def count_params(
368380
data[n] = {
369381
"type": m.__class__.__name__,
370382
"trainable_params": 0,
371-
"nontrainable_params": 0
383+
"trainable_params_dtype": None,
384+
"trainable_params_size_bits": 0,
385+
"nontrainable_params": 0,
386+
"nontrainable_params_dtype": None,
387+
"nontrainable_params_size_bits": 0
372388
}
373389

374390
for p in m.parameters():

0 commit comments

Comments
 (0)