Skip to content

Commit 12f93a2

Browse files
authored
[BUGFIX] Fix performance hit for grouped convolutions (#216)
* Zero out conv weight matrices after resize. * Improve speed of small grouped convolutions with a single GEMM. * Implement std::vector grouped_weights. * Revert "Implement std::vector grouped_weights" (this reverts commit e78e191). * Improve grouped convolutions for Conv1D by... ignoring them for now.
1 parent d53946a commit 12f93a2

2 files changed

Lines changed: 38 additions & 156 deletions

File tree

NAM/conv1d.cpp

Lines changed: 28 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -55,10 +55,16 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
5555
this->_num_groups = groups;
5656
this->_weight.resize(kernel_size);
5757
for (size_t i = 0; i < this->_weight.size(); i++)
58+
{
5859
this->_weight[i].resize(out_channels,
5960
in_channels); // y = Ax, input array (C,L)
61+
this->_weight[i].setZero();
62+
}
6063
if (do_bias)
64+
{
6165
this->_bias.resize(out_channels);
66+
this->_bias.setZero();
67+
}
6268
else
6369
this->_bias.resize(0);
6470
this->_dilation = _dilation;
@@ -104,54 +110,22 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
104110
// Zero output before processing
105111
_output.leftCols(num_frames).setZero();
106112

107-
const int numGroups = this->_num_groups;
108-
const long in_channels = get_in_channels();
109-
const long out_channels = get_out_channels();
110-
const long in_per_group = in_channels / numGroups;
111-
const long out_per_group = out_channels / numGroups;
112-
113113
// Process from ring buffer with dilation lookback
114114
// After Write(), data is at positions [_write_pos, _write_pos+num_frames-1]
115115
// For kernel tap k with offset, we need to read from _write_pos + offset
116116
// The offset is negative (looking back), so _write_pos + offset reads from earlier positions
117-
// The original process_() reads: input.middleCols(i_start + offset, ncols)
118-
// where i_start is the current position and offset is negative for lookback
119-
120-
if (numGroups == 1)
117+
//
118+
// Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
119+
// so we can use a single GEMM for all cases. A more advanced implementation could store
120+
// compact per-group weight matrices and loop over groups, but at typical model sizes
121+
// (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
122+
// and the single sparse GEMM approach is faster.
123+
for (size_t k = 0; k < this->_weight.size(); k++)
121124
{
122-
// Standard convolution (no grouping)
123-
for (size_t k = 0; k < this->_weight.size(); k++)
124-
{
125-
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
126-
const long lookback = -offset;
127-
auto input_block = _input_buffer.Read(num_frames, lookback);
128-
_output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
129-
}
130-
}
131-
else
132-
{
133-
// Grouped convolution: process each group separately
134-
for (int g = 0; g < numGroups; g++)
135-
{
136-
for (size_t k = 0; k < this->_weight.size(); k++)
137-
{
138-
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
139-
const long lookback = -offset;
140-
auto input_block = _input_buffer.Read(num_frames, lookback);
141-
142-
// Extract input slice for this group
143-
auto input_group = input_block.middleRows(g * in_per_group, in_per_group);
144-
145-
// Extract weight slice for this group
146-
auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
147-
148-
// Extract output slice for this group
149-
auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);
150-
151-
// Perform grouped convolution: output_group += weight_group * input_group
152-
output_group.noalias() += weight_group * input_group;
153-
}
154-
}
125+
const long offset = this->_dilation * (k + 1 - (long)this->_weight.size());
126+
const long lookback = -offset;
127+
auto input_block = _input_buffer.Read(num_frames, lookback);
128+
_output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
155129
}
156130

157131
// Add bias if present
@@ -167,49 +141,18 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
167141
void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols,
168142
const long j_start) const
169143
{
170-
const int numGroups = this->_num_groups;
171-
const long in_channels = get_in_channels();
172-
const long out_channels = get_out_channels();
173-
const long in_per_group = in_channels / numGroups;
174-
const long out_per_group = out_channels / numGroups;
175-
176-
if (numGroups == 1)
177-
{
178-
// Standard convolution (no grouping)
179-
for (size_t k = 0; k < this->_weight.size(); k++)
180-
{
181-
const long offset = this->_dilation * (k + 1 - this->_weight.size());
182-
if (k == 0)
183-
output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
184-
else
185-
output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
186-
}
187-
}
188-
else
144+
// Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
145+
// so we can use a single GEMM for all cases. A more advanced implementation could store
146+
// compact per-group weight matrices and loop over groups, but at typical model sizes
147+
// (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
148+
// and the single sparse GEMM approach is faster.
149+
for (size_t k = 0; k < this->_weight.size(); k++)
189150
{
190-
// Grouped convolution: process each group separately
191-
for (int g = 0; g < numGroups; g++)
192-
{
193-
for (size_t k = 0; k < this->_weight.size(); k++)
194-
{
195-
const long offset = this->_dilation * (k + 1 - this->_weight.size());
196-
197-
// Extract input slice for this group
198-
auto input_group = input.middleCols(i_start + offset, ncols).middleRows(g * in_per_group, in_per_group);
199-
200-
// Extract weight slice for this group
201-
auto weight_group = this->_weight[k].block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
202-
203-
// Extract output slice for this group
204-
auto output_group = output.middleCols(j_start, ncols).middleRows(g * out_per_group, out_per_group);
205-
206-
// Perform grouped convolution
207-
if (k == 0)
208-
output_group.noalias() = weight_group * input_group;
209-
else
210-
output_group.noalias() += weight_group * input_group;
211-
}
212-
}
151+
const long offset = this->_dilation * (k + 1 - this->_weight.size());
152+
if (k == 0)
153+
output.middleCols(j_start, ncols).noalias() = this->_weight[k] * input.middleCols(i_start + offset, ncols);
154+
else
155+
output.middleCols(j_start, ncols).noalias() += this->_weight[k] * input.middleCols(i_start + offset, ncols);
213156
}
214157
if (this->_bias.size() > 0)
215158
{

NAM/dsp.cpp

Lines changed: 10 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -332,9 +332,13 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool
332332

333333
this->_num_groups = groups;
334334
this->_weight.resize(out_channels, in_channels);
335+
this->_weight.setZero();
335336
this->_do_bias = _bias;
336337
if (_bias)
338+
{
337339
this->_bias.resize(out_channels);
340+
this->_bias.setZero();
341+
}
338342
}
339343

340344

@@ -374,45 +378,11 @@ void nam::Conv1x1::set_weights_(std::vector<float>::iterator& weights)
374378

375379
Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int num_frames) const
376380
{
377-
const int numGroups = this->_num_groups;
378-
const long in_channels = get_in_channels();
379-
const long out_channels = get_out_channels();
380-
const long in_per_group = in_channels / numGroups;
381-
const long out_per_group = out_channels / numGroups;
382-
383-
Eigen::MatrixXf result(out_channels, num_frames);
384-
385-
if (numGroups == 1)
386-
{
387-
// Standard convolution (no grouping)
388-
if (this->_do_bias)
389-
result = (this->_weight * input.leftCols(num_frames)).colwise() + this->_bias;
390-
else
391-
result = this->_weight * input.leftCols(num_frames);
392-
}
393-
else
394-
{
395-
// Grouped convolution: process each group separately
396-
result.setZero();
397-
for (int g = 0; g < numGroups; g++)
398-
{
399-
// Extract input slice for this group
400-
auto input_group = input.leftCols(num_frames).middleRows(g * in_per_group, in_per_group);
401-
402-
// Extract weight slice for this group
403-
auto weight_group = this->_weight.block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
404-
405-
// Extract output slice for this group
406-
auto output_group = result.middleRows(g * out_per_group, out_per_group);
381+
// Single GEMM for all cases - block-diagonal zero structure handles grouping
382+
Eigen::MatrixXf result = this->_weight * input.leftCols(num_frames);
407383

408-
// Perform grouped convolution: output_group = weight_group * input_group
409-
output_group.noalias() = weight_group * input_group;
410-
}
411-
412-
// Add bias if present
413-
if (this->_do_bias)
414-
result.colwise() += this->_bias;
415-
}
384+
if (this->_do_bias)
385+
result.colwise() += this->_bias;
416386

417387
return result;
418388
}
@@ -421,40 +391,9 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
421391
{
422392
assert(num_frames <= _output.cols());
423393

424-
const int numGroups = this->_num_groups;
425-
const long in_channels = get_in_channels();
426-
const long out_channels = get_out_channels();
427-
const long in_per_group = in_channels / numGroups;
428-
const long out_per_group = out_channels / numGroups;
429-
430-
if (numGroups == 1)
431-
{
432-
// Standard convolution (no grouping)
433-
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
434-
}
435-
else
436-
{
437-
// Grouped convolution: process each group separately
438-
_output.leftCols(num_frames).setZero();
439-
for (int g = 0; g < numGroups; g++)
440-
{
441-
// Extract input slice for this group
442-
auto input_group = input.leftCols(num_frames).middleRows(g * in_per_group, in_per_group);
443-
444-
// Extract weight slice for this group
445-
auto weight_group = this->_weight.block(g * out_per_group, g * in_per_group, out_per_group, in_per_group);
446-
447-
// Extract output slice for this group
448-
auto output_group = _output.leftCols(num_frames).middleRows(g * out_per_group, out_per_group);
394+
// Single GEMM for all cases - block-diagonal zero structure handles grouping
395+
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
449396

450-
// Perform grouped convolution: output_group = weight_group * input_group
451-
output_group.noalias() = weight_group * input_group;
452-
}
453-
}
454-
455-
// Add bias if present
456397
if (this->_do_bias)
457-
{
458398
_output.leftCols(num_frames).colwise() += this->_bias;
459-
}
460399
}

0 commit comments

Comments (0)