@@ -17,7 +17,10 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
1717 _input_mixin.SetMaxBufferSize (maxBufferSize);
1818 const long z_channels = this ->_conv .get_out_channels (); // This is 2*bottleneck when gated, bottleneck when not
1919 _z.resize (z_channels, maxBufferSize);
20- _1x1.SetMaxBufferSize (maxBufferSize);
20+ if (this ->_layer1x1 )
21+ {
22+ this ->_layer1x1 ->SetMaxBufferSize (maxBufferSize);
23+ }
2124 // Pre-allocate output buffers
2225 const long channels = this ->get_channels ();
2326 this ->_output_next_layer .resize (channels, maxBufferSize);
@@ -47,8 +50,8 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
4750 this ->_activation_pre_film ->SetMaxBufferSize (maxBufferSize);
4851 if (this ->_activation_post_film )
4952 this ->_activation_post_film ->SetMaxBufferSize (maxBufferSize);
50- if (this ->_1x1_post_film )
51- this ->_1x1_post_film ->SetMaxBufferSize (maxBufferSize);
53+ if (this ->_layer1x1_post_film )
54+ this ->_layer1x1_post_film ->SetMaxBufferSize (maxBufferSize);
5255 if (this ->_head1x1_post_film )
5356 this ->_head1x1_post_film ->SetMaxBufferSize (maxBufferSize);
5457}
@@ -57,7 +60,10 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
5760{
5861 this ->_conv .set_weights_ (weights);
5962 this ->_input_mixin .set_weights_ (weights);
60- this ->_1x1 .set_weights_ (weights);
63+ if (this ->_layer1x1 )
64+ {
65+ this ->_layer1x1 ->set_weights_ (weights);
66+ }
6167 if (this ->_head1x1 )
6268 {
6369 this ->_head1x1 ->set_weights_ (weights);
@@ -75,8 +81,8 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
7581 this ->_activation_pre_film ->set_weights_ (weights);
7682 if (this ->_activation_post_film )
7783 this ->_activation_post_film ->set_weights_ (weights);
78- if (this ->_1x1_post_film )
79- this ->_1x1_post_film ->set_weights_ (weights);
84+ if (this ->_layer1x1_post_film )
85+ this ->_layer1x1_post_film ->set_weights_ (weights);
8086 if (this ->_head1x1_post_film )
8187 this ->_head1x1_post_film ->set_weights_ (weights);
8288}
@@ -137,7 +143,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
137143 {
138144 this ->_activation_post_film ->Process_ (this ->_z , condition, num_frames);
139145 }
140- _1x1.process_ (_z, num_frames);
146+ if (this ->_layer1x1 )
147+ {
148+ this ->_layer1x1 ->process_ (this ->_z , num_frames);
149+ }
141150 }
142151 else if (this ->_gating_mode == GatingMode::GATED)
143152 {
@@ -153,7 +162,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
153162 this ->_z .topRows (bottleneck).leftCols (num_frames).noalias () =
154163 this ->_activation_post_film ->GetOutput ().leftCols (num_frames);
155164 }
156- _1x1.process_ (this ->_z .topRows (bottleneck), num_frames);
165+ if (this ->_layer1x1 )
166+ {
167+ this ->_layer1x1 ->process_ (this ->_z .topRows (bottleneck), num_frames);
168+ }
157169 }
158170 else if (this ->_gating_mode == GatingMode::BLENDED)
159171 {
@@ -169,11 +181,14 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
169181 this ->_z .topRows (bottleneck).leftCols (num_frames).noalias () =
170182 this ->_activation_post_film ->GetOutput ().leftCols (num_frames);
171183 }
172- _1x1.process_ (this ->_z .topRows (bottleneck), num_frames);
173- if (this ->_1x1_post_film )
184+ if (this ->_layer1x1 )
174185 {
175- Eigen::MatrixXf& _1x1_output = this ->_1x1 .GetOutput ();
176- this ->_1x1_post_film ->Process_ (_1x1_output, condition, num_frames);
186+ this ->_layer1x1 ->process_ (this ->_z .topRows (bottleneck), num_frames);
187+ if (this ->_layer1x1_post_film )
188+ {
189+ Eigen::MatrixXf& layer1x1_output = this ->_layer1x1 ->GetOutput ();
190+ this ->_layer1x1_post_film ->Process_ (layer1x1_output, condition, num_frames);
191+ }
177192 }
178193 }
179194
@@ -187,7 +202,6 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
187202 {
188203 this ->_head1x1 ->process_ (this ->_z .topRows (bottleneck).leftCols (num_frames), num_frames);
189204 }
190- this ->_head1x1 ->process (this ->_z .topRows (bottleneck).leftCols (num_frames), num_frames);
191205 if (this ->_head1x1_post_film )
192206 {
193207 Eigen::MatrixXf& head1x1_output = this ->_head1x1 ->GetOutput ();
@@ -205,9 +219,17 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
205219 this ->_output_head .leftCols (num_frames).noalias () = this ->_z .topRows (bottleneck).leftCols (num_frames);
206220 }
207221
208- // Store output to next layer (residual connection: input + _1x1 output)
209- this ->_output_next_layer .leftCols (num_frames).noalias () =
210- input.leftCols (num_frames) + _1x1.GetOutput ().leftCols (num_frames);
222+ // Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
223+ if (this ->_layer1x1 )
224+ {
225+ this ->_output_next_layer .leftCols (num_frames).noalias () =
226+ input.leftCols (num_frames) + this ->_layer1x1 ->GetOutput ().leftCols (num_frames);
227+ }
228+ else
229+ {
230+ // If layer1x1 is inactive, residual connection is just the input (identity)
231+ this ->_output_next_layer .leftCols (num_frames).noalias () = input.leftCols (num_frames);
232+ }
211233}
212234
213235// LayerArray =================================================================
@@ -224,10 +246,10 @@ nam::wavenet::_LayerArray::_LayerArray(const LayerArrayParams& params)
224246 LayerParams layer_params (
225247 params.condition_size , params.channels , params.bottleneck , params.kernel_size , params.dilations [i],
226248 params.activation_configs [i], params.gating_modes [i], params.groups_input , params.groups_input_mixin ,
227- params.groups_1x1 , params.head1x1_params , params.secondary_activation_configs [i], params. conv_pre_film_params ,
228- params.conv_post_film_params , params.input_mixin_pre_film_params , params.input_mixin_post_film_params ,
229- params.activation_pre_film_params , params.activation_post_film_params , params._1x1_post_film_params ,
230- params.head1x1_post_film_params );
249+ params.layer1x1_params , params.head1x1_params , params.secondary_activation_configs [i],
250+ params.conv_pre_film_params , params.conv_post_film_params , params.input_mixin_pre_film_params ,
251+ params.input_mixin_post_film_params , params.activation_pre_film_params , params.activation_post_film_params ,
252+ params._layer1x1_post_film_params , params. head1x1_post_film_params );
231253 this ->_layers .push_back (_Layer (layer_params));
232254 }
233255}
@@ -570,11 +592,21 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
570592
571593 const int groups = layer_config.value (" groups_input" , 1 ); // defaults to 1
572594 const int groups_input_mixin = layer_config.value (" groups_input_mixin" , 1 ); // defaults to 1
573- const int groups_1x1 = layer_config.value (" groups_1x1" , 1 ); // defaults to 1
574595
575596 const int channels = layer_config[" channels" ];
576597 const int bottleneck = layer_config.value (" bottleneck" , channels); // defaults to channels if not present
577598
599+ // Parse layer1x1 parameters
600+ bool layer1x1_active = true ; // default to active if not present
601+ int layer1x1_groups = 1 ;
602+ if (layer_config.find (" layer1x1" ) != layer_config.end ())
603+ {
604+ const auto & layer1x1_config = layer_config[" layer1x1" ];
605+ layer1x1_active = layer1x1_config[" active" ]; // default to active
606+ layer1x1_groups = layer1x1_config[" groups" ];
607+ }
608+ nam::wavenet::Layer1x1Params layer1x1_params (layer1x1_active, layer1x1_groups);
609+
578610 const int input_size = layer_config[" input_size" ];
579611 const int condition_size = layer_config[" condition_size" ];
580612 const int head_size = layer_config[" head_size" ];
@@ -742,9 +774,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
742774 bool head1x1_active = false ;
743775 int head1x1_out_channels = channels;
744776 int head1x1_groups = 1 ;
745- if (layer_config.find (" head_1x1 " ) != layer_config.end ())
777+ if (layer_config.find (" head1x1 " ) != layer_config.end ())
746778 {
747- const auto & head1x1_config = layer_config[" head_1x1 " ];
779+ const auto & head1x1_config = layer_config[" head1x1 " ];
748780 head1x1_active = head1x1_config[" active" ];
749781 head1x1_out_channels = head1x1_config[" out_channels" ];
750782 head1x1_groups = head1x1_config[" groups" ];
@@ -771,15 +803,22 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
771803 nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params (" input_mixin_post_film" );
772804 nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params (" activation_pre_film" );
773805 nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params (" activation_post_film" );
774- nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params (" 1x1_post_film " );
806+ nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params (" layer1x1_post_film " );
775807 nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params (" head1x1_post_film" );
776808
809+ // Validation: if layer1x1_post_film is active, layer1x1 must also be active
810+ if (_layer1x1_post_film_params.active && !layer1x1_active)
811+ {
812+ throw std::runtime_error (" Layer array " + std::to_string (i)
813+ + " : layer1x1_post_film cannot be active when layer1x1.active is false" );
814+ }
815+
777816 layer_array_params.push_back (nam::wavenet::LayerArrayParams (
778817 input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations,
779- std::move (activation_configs), std::move (gating_modes), head_bias, groups, groups_input_mixin, groups_1x1 ,
818+ std::move (activation_configs), std::move (gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params ,
780819 head1x1_params, std::move (secondary_activation_configs), conv_pre_film_params, conv_post_film_params,
781820 input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
782- activation_post_film_params, _1x1_post_film_params , head1x1_post_film_params));
821+ activation_post_film_params, _layer1x1_post_film_params , head1x1_post_film_params));
783822 }
784823 const bool with_head = !config[" head" ].is_null ();
785824 const float head_scale = config[" head_scale" ];
0 commit comments