Skip to content

Commit 7b8ff95

Browse files
author
Esteban Gómez
committed
Add exclude_from_ops feature
1 parent 8e37044 commit 7b8ff95

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

src/moduleprofiler/ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ def _default_ops_fn(
1515
return None
1616

1717

18+
def _excluded_ops_fn(
19+
module: nn.Module,
20+
input: Tuple[torch.Tensor],
21+
output: torch.Tensor
22+
) -> Any:
23+
return None
24+
25+
1826
def _identity_ops_fn(
1927
module: nn.Identity,
2028
input: Tuple[torch.Tensor],
@@ -635,6 +643,9 @@ def get_default_ops_map() -> dict:
635643
# Default method
636644
"default": _default_ops_fn,
637645

646+
# Excluded module method
647+
"excluded": _excluded_ops_fn,
648+
638649
# Layers
639650
nn.Identity: _identity_ops_fn,
640651
nn.Linear: _linear_ops_fn,

src/moduleprofiler/profiler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Any,
77
Callable,
88
Dict,
9+
List,
910
Optional,
1011
Tuple,
1112
Union
@@ -43,6 +44,8 @@ class ModuleProfiler:
4344
their corresponding functions useed to trace the its size.
4445
ops_fn_map (dict): Dictionary containing a map between modules and
4546
their corresponding function to estimate the number of operations.
47+
exclude_from_ops (Optional[List[nn.Module]]): Modules to exclude from
48+
ops calculations.
4649
ts_fmt (str): Timestamp format used to print messages if
4750
`verbose=True`.
4851
verbose (bool): If ``True``, enabled verbose output mode.
@@ -56,6 +59,7 @@ def __init__(
5659
inference_end_attr: str = "__inference_end__",
5760
io_size_fn_map: dict = get_default_io_size_map(),
5861
ops_fn_map: dict = get_default_ops_map(),
62+
exclude_from_ops: Optional[List[nn.Module]] = None,
5963
ts_fmt: str = "%Y-%m-%d %H:%M:%S",
6064
verbose: bool = False
6165
) -> None:
@@ -69,6 +73,7 @@ def __init__(
6973
self.inference_end_attr = inference_end_attr
7074
self.io_size_fn_map = io_size_fn_map
7175
self.ops_fn_map = ops_fn_map
76+
self.exclude_from_ops = exclude_from_ops
7277
self.verbose = verbose
7378
self._logger = Logger(ts_fmt=ts_fmt)
7479
self._hook_handles = []
@@ -312,7 +317,14 @@ def _ops_fn(
312317
"""
313318
# Obtain method to estimate ops
314319
if module.__class__ in self.ops_fn_map:
315-
ops_fn = self.ops_fn_map[type(module)]
320+
if (
321+
self.exclude_from_ops is not None
322+
and module.__class__ in self.exclude_from_ops
323+
):
324+
ops_fn = self.ops_fn_map["excluded"]
325+
326+
else:
327+
ops_fn = self.ops_fn_map[type(module)]
316328

317329
else:
318330
ops_fn = self.ops_fn_map["default"]

0 commit comments

Comments
 (0)