Skip to content

Commit 405ffbb

Browse files
author
Esteban Gómez Mellado
committed
Fix ConvTranspose1d and ConvTranspose2d adds per filter estimation
1 parent ad62c1f commit 405ffbb

1 file changed

Lines changed: 46 additions & 25 deletions

File tree

src/moduleprofiler/ops.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,38 +107,32 @@ def _conv2d_ops_fn(
107107
return int(total_ops)
108108

109109

110-
def _convtransposend_filter_addition_ops(
111-
batch_size: int,
112-
module: Union[nn.ConvTranspose1d, nn.ConvTranspose2d],
110+
def _convtranspose1d_filter_addition_ops(
111+
module: nn.ConvTranspose1d,
113112
input: Tuple[torch.Tensor],
114-
output: torch.Tensor
113+
output: torch.Tensor
115114
) -> int:
116-
# Get input filled with ones
117-
x_ones = torch.ones_like(input[0])
115+
# Get single-channel input with the same length as the original input, filled with ones
116+
x_ones = torch.ones((1, 1, input[0].size(-1)))
118117

119-
# Get copy of input modules but with weight filled with ones
120-
convtransposend_ones = type(module)(
121-
in_channels=module.in_channels,
122-
out_channels=module.out_channels,
118+
# Get copy of the input module with a single I/O channel to compute additions pattern
119+
convtranspose1d_ones = nn.ConvTranspose1d(
120+
in_channels=1,
121+
out_channels=1,
123122
kernel_size=module.kernel_size,
124123
stride=module.stride,
125124
padding=module.padding,
126125
padding_mode=module.padding_mode,
127126
dilation=module.dilation,
128-
groups=module.groups,
127+
groups=1,
129128
bias=False
130129
)
131-
torch.nn.init.ones_(convtransposend_ones.weight)
130+
torch.nn.init.ones_(convtranspose1d_ones.weight)
132131

133-
# Compute additions pattern
134-
total_addition_ops = convtransposend_ones(x_ones) - 1.0
132+
# Compute additions pattern for a single filter
133+
total_addition_ops = convtranspose1d_ones(x_ones) - 1.0
135134
total_addition_ops = torch.sum(total_addition_ops)
136135

137-
# NOTE: This number is for all filters and for the whole batch so it is
138-
# necessary to calculate for a single filter
139-
num_filters = (module.in_channels * module.out_channels) / module.groups
140-
total_addition_ops = (total_addition_ops / batch_size) / num_filters
141-
142136
return int(total_addition_ops)
143137

144138

@@ -154,8 +148,7 @@ def _convtranspose1d_ops_fn(
154148
batch_size = 1 if x0.ndim == 1 else x0.size(0)
155149

156150
# Get addition ops
157-
total_addition_ops = _convtransposend_filter_addition_ops(
158-
batch_size,
151+
adds_per_filter = _convtranspose1d_filter_addition_ops(
159152
module,
160153
input,
161154
output
@@ -164,7 +157,7 @@ def _convtranspose1d_ops_fn(
164157
total_ops = (
165158
batch_size
166159
* ((module.in_channels * module.out_channels) / module.groups)
167-
* (output.size(-1) * (module.kernel_size[0] + 1) + total_addition_ops)
160+
* (output.size(-1) * (module.kernel_size[0] + 1) + adds_per_filter)
168161
)
169162

170163
# Add bias correction
@@ -174,6 +167,35 @@ def _convtranspose1d_ops_fn(
174167
return int(total_ops)
175168

176169

170+
def _convtranspose2d_filter_addition_ops(
    module: nn.ConvTranspose2d,
    input: Tuple[torch.Tensor],
    output: torch.Tensor
) -> int:
    """Count the additions needed to overlap-add the output of one filter.

    A single-channel copy of ``module`` with all weights set to one is applied
    to an all-ones probe that has the spatial size of ``input[0]``. Each
    output position then holds the number of kernel taps that land on it, and
    a position hit by ``k`` taps needs ``k - 1`` additions.

    Args:
        module: Transposed convolution whose ``kernel_size``, ``stride``,
            ``padding``, ``padding_mode`` and ``dilation`` are replicated.
        input: Inputs received by ``module``; only the last two dimensions of
            ``input[0]`` are read.
        output: Output produced by ``module``. Unused; kept so the signature
            matches the other helpers.

    Returns:
        Number of additions for a single (input channel, output channel)
        filter and a single batch element.
    """
    # Single-channel probe with the same spatial size as the input. The
    # explicit batch dimension mirrors the 1-D helper and avoids relying on
    # unbatched-input support.
    x_ones = torch.ones((1, 1, input[0].size(-2), input[0].size(-1)))

    # Single I/O channel copy of the module used to obtain the overlap pattern
    convtranspose2d_ones = nn.ConvTranspose2d(
        in_channels=1,
        out_channels=1,
        kernel_size=module.kernel_size,
        stride=module.stride,
        padding=module.padding,
        padding_mode=module.padding_mode,
        dilation=module.dilation,
        groups=1,
        bias=False
    )
    torch.nn.init.ones_(convtranspose2d_ones.weight)

    # Only the values are needed, so skip building an autograd graph
    with torch.no_grad():
        # Round to guard against backend-dependent floating point error
        contributions = torch.round(convtranspose2d_ones(x_ones))

    # Positions that receive no contribution (e.g. stride larger than the
    # kernel extent) need zero additions, not -1, hence the clamp.
    additions = torch.clamp(contributions - 1.0, min=0.0)
    total_addition_ops = torch.sum(additions)

    return int(total_addition_ops)
197+
198+
177199
def _convtranspose2d_ops_fn(
178200
module: nn.ConvTranspose2d,
179201
input: Tuple[torch.Tensor],
@@ -186,8 +208,7 @@ def _convtranspose2d_ops_fn(
186208
batch_size = 1 if x0.ndim == 2 else x0.size(0)
187209

188210
# Get addition ops
189-
total_addition_ops = _convtransposend_filter_addition_ops(
190-
batch_size,
211+
adds_per_filter = _convtranspose2d_filter_addition_ops(
191212
module,
192213
input,
193214
output
@@ -199,7 +220,7 @@ def _convtranspose2d_ops_fn(
199220
* (
200221
output.size(-1) * output.size(-2)
201222
* (module.kernel_size[0] * module.kernel_size[1] + 1)
202-
+ total_addition_ops
223+
+ adds_per_filter
203224
)
204225
)
205226

0 commit comments

Comments
 (0)