66 Any ,
77 Callable ,
88 Dict ,
9+ List ,
910 Optional ,
1011 Tuple ,
1112 Union
@@ -43,6 +44,8 @@ class ModuleProfiler:
4344 their corresponding functions useed to trace the its size.
4445 ops_fn_map (dict): Dictionary containing a map between modules and
4546 their corresponding function to estimate the number of operations.
47+ exclude_from_ops (Optional[List[nn.Module]]): Modules to exclude from
48+ ops calculations.
4649 ts_fmt (str): Timestamp format used to print messages if
4750 `verbose=True`.
4851 verbose (bool): If ``True``, enabled verbose output mode.
@@ -56,6 +59,7 @@ def __init__(
5659 inference_end_attr : str = "__inference_end__" ,
5760 io_size_fn_map : dict = get_default_io_size_map (),
5861 ops_fn_map : dict = get_default_ops_map (),
62+ exclude_from_ops : Optional [List [nn .Module ]] = None ,
5963 ts_fmt : str = "%Y-%m-%d %H:%M:%S" ,
6064 verbose : bool = False
6165 ) -> None :
@@ -69,6 +73,7 @@ def __init__(
6973 self .inference_end_attr = inference_end_attr
7074 self .io_size_fn_map = io_size_fn_map
7175 self .ops_fn_map = ops_fn_map
76+ self .exclude_from_ops = exclude_from_ops
7277 self .verbose = verbose
7378 self ._logger = Logger (ts_fmt = ts_fmt )
7479 self ._hook_handles = []
@@ -312,7 +317,14 @@ def _ops_fn(
312317 """
313318 # Obtain method to estimate ops
314319 if module .__class__ in self .ops_fn_map :
315- ops_fn = self .ops_fn_map [type (module )]
320+ if (
321+ self .exclude_from_ops is not None
322+ and module .__class__ in self .exclude_from_ops
323+ ):
324+ ops_fn = self .ops_fn_map ["excluded" ]
325+
326+ else :
327+ ops_fn = self .ops_fn_map [type (module )]
316328
317329 else :
318330 ops_fn = self .ops_fn_map ["default" ]
0 commit comments