@@ -55,10 +55,16 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
5555 this ->_num_groups = groups;
5656 this ->_weight .resize (kernel_size);
5757 for (size_t i = 0 ; i < this ->_weight .size (); i++)
58+ {
5859 this ->_weight [i].resize (out_channels,
5960 in_channels); // y = Ax, input array (C,L)
61+ this ->_weight [i].setZero ();
62+ }
6063 if (do_bias)
64+ {
6165 this ->_bias .resize (out_channels);
66+ this ->_bias .setZero ();
67+ }
6268 else
6369 this ->_bias .resize (0 );
6470 this ->_dilation = _dilation;
@@ -104,54 +110,22 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
104110 // Zero output before processing
105111 _output.leftCols (num_frames).setZero ();
106112
107- const int numGroups = this ->_num_groups ;
108- const long in_channels = get_in_channels ();
109- const long out_channels = get_out_channels ();
110- const long in_per_group = in_channels / numGroups;
111- const long out_per_group = out_channels / numGroups;
112-
113113 // Process from ring buffer with dilation lookback
114114 // After Write(), data is at positions [_write_pos, _write_pos+num_frames-1]
115115 // For kernel tap k with offset, we need to read from _write_pos + offset
116116 // The offset is negative (looking back), so _write_pos + offset reads from earlier positions
117- // The original process_() reads: input.middleCols(i_start + offset, ncols)
118- // where i_start is the current position and offset is negative for lookback
119-
120- if (numGroups == 1 )
117+ //
118+ // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
119+ // so we can use a single GEMM for all cases. A more advanced implementation could store
120+ // compact per-group weight matrices and loop over groups, but at typical model sizes
121+ // (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
122+ // and the single sparse GEMM approach is faster.
123+ for (size_t k = 0 ; k < this ->_weight .size (); k++)
121124 {
122- // Standard convolution (no grouping)
123- for (size_t k = 0 ; k < this ->_weight .size (); k++)
124- {
125- const long offset = this ->_dilation * (k + 1 - (long )this ->_weight .size ());
126- const long lookback = -offset;
127- auto input_block = _input_buffer.Read (num_frames, lookback);
128- _output.leftCols (num_frames).noalias () += this ->_weight [k] * input_block;
129- }
130- }
131- else
132- {
133- // Grouped convolution: process each group separately
134- for (int g = 0 ; g < numGroups; g++)
135- {
136- for (size_t k = 0 ; k < this ->_weight .size (); k++)
137- {
138- const long offset = this ->_dilation * (k + 1 - (long )this ->_weight .size ());
139- const long lookback = -offset;
140- auto input_block = _input_buffer.Read (num_frames, lookback);
141-
142- // Extract input slice for this group
143- auto input_group = input_block.middleRows (g * in_per_group, in_per_group);
144-
145- // Extract weight slice for this group
146- auto weight_group = this ->_weight [k].block (g * out_per_group, g * in_per_group, out_per_group, in_per_group);
147-
148- // Extract output slice for this group
149- auto output_group = _output.leftCols (num_frames).middleRows (g * out_per_group, out_per_group);
150-
151- // Perform grouped convolution: output_group += weight_group * input_group
152- output_group.noalias () += weight_group * input_group;
153- }
154- }
125+ const long offset = this ->_dilation * (k + 1 - (long )this ->_weight .size ());
126+ const long lookback = -offset;
127+ auto input_block = _input_buffer.Read (num_frames, lookback);
128+ _output.leftCols (num_frames).noalias () += this ->_weight [k] * input_block;
155129 }
156130
157131 // Add bias if present
@@ -167,49 +141,18 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
167141void Conv1D::process_ (const Eigen::MatrixXf& input, Eigen::MatrixXf& output, const long i_start, const long ncols,
168142 const long j_start) const
169143{
170- const int numGroups = this ->_num_groups ;
171- const long in_channels = get_in_channels ();
172- const long out_channels = get_out_channels ();
173- const long in_per_group = in_channels / numGroups;
174- const long out_per_group = out_channels / numGroups;
175-
176- if (numGroups == 1 )
177- {
178- // Standard convolution (no grouping)
179- for (size_t k = 0 ; k < this ->_weight .size (); k++)
180- {
181- const long offset = this ->_dilation * (k + 1 - this ->_weight .size ());
182- if (k == 0 )
183- output.middleCols (j_start, ncols).noalias () = this ->_weight [k] * input.middleCols (i_start + offset, ncols);
184- else
185- output.middleCols (j_start, ncols).noalias () += this ->_weight [k] * input.middleCols (i_start + offset, ncols);
186- }
187- }
188- else
144+ // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
145+ // so we can use a single GEMM for all cases. A more advanced implementation could store
146+ // compact per-group weight matrices and loop over groups, but at typical model sizes
147+ // (e.g. 8 channels, 4 groups, 64 samples), the GEMM call overhead tends to dominate
148+ // and the single sparse GEMM approach is faster.
149+ for (size_t k = 0 ; k < this ->_weight .size (); k++)
189150 {
190- // Grouped convolution: process each group separately
191- for (int g = 0 ; g < numGroups; g++)
192- {
193- for (size_t k = 0 ; k < this ->_weight .size (); k++)
194- {
195- const long offset = this ->_dilation * (k + 1 - this ->_weight .size ());
196-
197- // Extract input slice for this group
198- auto input_group = input.middleCols (i_start + offset, ncols).middleRows (g * in_per_group, in_per_group);
199-
200- // Extract weight slice for this group
201- auto weight_group = this ->_weight [k].block (g * out_per_group, g * in_per_group, out_per_group, in_per_group);
202-
203- // Extract output slice for this group
204- auto output_group = output.middleCols (j_start, ncols).middleRows (g * out_per_group, out_per_group);
205-
206- // Perform grouped convolution
207- if (k == 0 )
208- output_group.noalias () = weight_group * input_group;
209- else
210- output_group.noalias () += weight_group * input_group;
211- }
212- }
151+ const long offset = this ->_dilation * (k + 1 - this ->_weight .size ());
152+ if (k == 0 )
153+ output.middleCols (j_start, ncols).noalias () = this ->_weight [k] * input.middleCols (i_start + offset, ncols);
154+ else
155+ output.middleCols (j_start, ncols).noalias () += this ->_weight [k] * input.middleCols (i_start + offset, ncols);
213156 }
214157 if (this ->_bias .size () > 0 )
215158 {
0 commit comments