@@ -107,38 +107,32 @@ def _conv2d_ops_fn(
107107 return int (total_ops )
108108
109109
110- def _convtransposend_filter_addition_ops (
111- batch_size : int ,
112- module : Union [nn .ConvTranspose1d , nn .ConvTranspose2d ],
110+ def _convtranspose1d_filter_addition_ops (
111+ module : nn .ConvTranspose1d ,
113112 input : Tuple [torch .Tensor ],
114- output : torch .Tensor
113+ output : torch .Tensor
115114) -> int :
116- # Get input filled with ones
117- x_ones = torch .ones_like ( input [0 ])
115+ # Get input with same output size but filled with ones
116+ x_ones = torch .ones (( 1 , 1 , input [0 ]. size ( - 1 )) )
118117
119- # Get copy of input modules but with weight filled with ones
120- convtransposend_ones = type ( module ) (
121- in_channels = module . in_channels ,
122- out_channels = module . out_channels ,
118+ # Get copy of input for single I/O channel to compute additions pattern
119+ convtranspose1d_ones = nn . ConvTranspose1d (
120+ in_channels = 1 ,
121+ out_channels = 1 ,
123122 kernel_size = module .kernel_size ,
124123 stride = module .stride ,
125124 padding = module .padding ,
126125 padding_mode = module .padding_mode ,
127126 dilation = module .dilation ,
128- groups = module . groups ,
127+ groups = 1 ,
129128 bias = False
130129 )
131- torch .nn .init .ones_ (convtransposend_ones .weight )
130+ torch .nn .init .ones_ (convtranspose1d_ones .weight )
132131
133- # Compute additions pattern
134- total_addition_ops = convtransposend_ones (x_ones ) - 1.0
132+ # Compute additions pattern for a single filter
133+ total_addition_ops = convtranspose1d_ones (x_ones ) - 1.0
135134 total_addition_ops = torch .sum (total_addition_ops )
136135
137- # NOTE: This number is for all filters and for the whole batch so it is
138- # necessary to calculate for a single filter
139- num_filters = (module .in_channels * module .out_channels ) / module .groups
140- total_addition_ops = (total_addition_ops / batch_size ) / num_filters
141-
142136 return int (total_addition_ops )
143137
144138
@@ -154,8 +148,7 @@ def _convtranspose1d_ops_fn(
154148 batch_size = 1 if x0 .ndim == 1 else x0 .size (0 )
155149
156150 # Get addition ops
157- total_addition_ops = _convtransposend_filter_addition_ops (
158- batch_size ,
151+ adds_per_filter = _convtranspose1d_filter_addition_ops (
159152 module ,
160153 input ,
161154 output
@@ -164,7 +157,7 @@ def _convtranspose1d_ops_fn(
164157 total_ops = (
165158 batch_size
166159 * ((module .in_channels * module .out_channels ) / module .groups )
167- * (output .size (- 1 ) * (module .kernel_size [0 ] + 1 ) + total_addition_ops )
160+ * (output .size (- 1 ) * (module .kernel_size [0 ] + 1 ) + adds_per_filter )
168161 )
169162
170163 # Add bias correction
@@ -174,6 +167,35 @@ def _convtranspose1d_ops_fn(
174167 return int (total_ops )
175168
176169
170+ def _convtranspose2d_filter_addition_ops (
171+ module : nn .ConvTranspose1d ,
172+ input : Tuple [torch .Tensor ],
173+ output : torch .Tensor
174+ ) -> int :
175+ # Get input with same output size but filled with ones
176+ x_ones = torch .ones ((1 , input [0 ].size (- 2 ), input [0 ].size (- 1 )))
177+
178+ # Get copy of input for single I/O channel to compute additions pattern
179+ convtranspose2d_ones = nn .ConvTranspose2d (
180+ in_channels = 1 ,
181+ out_channels = 1 ,
182+ kernel_size = module .kernel_size ,
183+ stride = module .stride ,
184+ padding = module .padding ,
185+ padding_mode = module .padding_mode ,
186+ dilation = module .dilation ,
187+ groups = 1 ,
188+ bias = False
189+ )
190+ torch .nn .init .ones_ (convtranspose2d_ones .weight )
191+
192+ # Compute additions pattern for a single filter
193+ total_addition_ops = convtranspose2d_ones (x_ones ) - 1.0
194+ total_addition_ops = torch .sum (total_addition_ops )
195+
196+ return int (total_addition_ops )
197+
198+
177199def _convtranspose2d_ops_fn (
178200 module : nn .ConvTranspose2d ,
179201 input : Tuple [torch .Tensor ],
@@ -186,8 +208,7 @@ def _convtranspose2d_ops_fn(
186208 batch_size = 1 if x0 .ndim == 2 else x0 .size (0 )
187209
188210 # Get addition ops
189- total_addition_ops = _convtransposend_filter_addition_ops (
190- batch_size ,
211+ adds_per_filter = _convtranspose2d_filter_addition_ops (
191212 module ,
192213 input ,
193214 output
@@ -199,7 +220,7 @@ def _convtranspose2d_ops_fn(
199220 * (
200221 output .size (- 1 ) * output .size (- 2 )
201222 * (module .kernel_size [0 ] * module .kernel_size [1 ] + 1 )
202- + total_addition_ops
223+ + adds_per_filter
203224 )
204225 )
205226
0 commit comments