Skip to content

Commit d53946a

Browse files
authored
[FEATURE] WaveNet: Make layer1x1 (formerly 1x1) optional, rename .nam key "head_1x1" to "head1x1" (#214)
* Layer1x1
* Add tests for Layer1x1 functionality in WaveNet
  - Introduced a new test file `test_layer1x1.cpp` to validate the behavior of the Layer1x1 component in the WaveNet architecture.
  - Implemented multiple test cases to check both active and inactive states of Layer1x1, ensuring correct processing of inputs and outputs.
  - Added validation for error handling when the bottleneck does not match channels in inactive Layer1x1 configurations.
  - Enhanced tests to cover scenarios with grouped Layer1x1 convolutions and post-FiLM behavior, ensuring comprehensive coverage of the new functionality.
* Fix up tests
* Fix bug
* Change .nam key from head_1x1 to head1x1
1 parent 7050f78 commit d53946a

13 files changed

Lines changed: 800 additions & 193 deletions

NAM/wavenet.cpp

Lines changed: 65 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,10 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
1717
_input_mixin.SetMaxBufferSize(maxBufferSize);
1818
const long z_channels = this->_conv.get_out_channels(); // This is 2*bottleneck when gated, bottleneck when not
1919
_z.resize(z_channels, maxBufferSize);
20-
_1x1.SetMaxBufferSize(maxBufferSize);
20+
if (this->_layer1x1)
21+
{
22+
this->_layer1x1->SetMaxBufferSize(maxBufferSize);
23+
}
2124
// Pre-allocate output buffers
2225
const long channels = this->get_channels();
2326
this->_output_next_layer.resize(channels, maxBufferSize);
@@ -47,8 +50,8 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
4750
this->_activation_pre_film->SetMaxBufferSize(maxBufferSize);
4851
if (this->_activation_post_film)
4952
this->_activation_post_film->SetMaxBufferSize(maxBufferSize);
50-
if (this->_1x1_post_film)
51-
this->_1x1_post_film->SetMaxBufferSize(maxBufferSize);
53+
if (this->_layer1x1_post_film)
54+
this->_layer1x1_post_film->SetMaxBufferSize(maxBufferSize);
5255
if (this->_head1x1_post_film)
5356
this->_head1x1_post_film->SetMaxBufferSize(maxBufferSize);
5457
}
@@ -57,7 +60,10 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
5760
{
5861
this->_conv.set_weights_(weights);
5962
this->_input_mixin.set_weights_(weights);
60-
this->_1x1.set_weights_(weights);
63+
if (this->_layer1x1)
64+
{
65+
this->_layer1x1->set_weights_(weights);
66+
}
6167
if (this->_head1x1)
6268
{
6369
this->_head1x1->set_weights_(weights);
@@ -75,8 +81,8 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
7581
this->_activation_pre_film->set_weights_(weights);
7682
if (this->_activation_post_film)
7783
this->_activation_post_film->set_weights_(weights);
78-
if (this->_1x1_post_film)
79-
this->_1x1_post_film->set_weights_(weights);
84+
if (this->_layer1x1_post_film)
85+
this->_layer1x1_post_film->set_weights_(weights);
8086
if (this->_head1x1_post_film)
8187
this->_head1x1_post_film->set_weights_(weights);
8288
}
@@ -137,7 +143,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
137143
{
138144
this->_activation_post_film->Process_(this->_z, condition, num_frames);
139145
}
140-
_1x1.process_(_z, num_frames);
146+
if (this->_layer1x1)
147+
{
148+
this->_layer1x1->process_(this->_z, num_frames);
149+
}
141150
}
142151
else if (this->_gating_mode == GatingMode::GATED)
143152
{
@@ -153,7 +162,10 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
153162
this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
154163
this->_activation_post_film->GetOutput().leftCols(num_frames);
155164
}
156-
_1x1.process_(this->_z.topRows(bottleneck), num_frames);
165+
if (this->_layer1x1)
166+
{
167+
this->_layer1x1->process_(this->_z.topRows(bottleneck), num_frames);
168+
}
157169
}
158170
else if (this->_gating_mode == GatingMode::BLENDED)
159171
{
@@ -169,11 +181,14 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
169181
this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
170182
this->_activation_post_film->GetOutput().leftCols(num_frames);
171183
}
172-
_1x1.process_(this->_z.topRows(bottleneck), num_frames);
173-
if (this->_1x1_post_film)
184+
if (this->_layer1x1)
174185
{
175-
Eigen::MatrixXf& _1x1_output = this->_1x1.GetOutput();
176-
this->_1x1_post_film->Process_(_1x1_output, condition, num_frames);
186+
this->_layer1x1->process_(this->_z.topRows(bottleneck), num_frames);
187+
if (this->_layer1x1_post_film)
188+
{
189+
Eigen::MatrixXf& layer1x1_output = this->_layer1x1->GetOutput();
190+
this->_layer1x1_post_film->Process_(layer1x1_output, condition, num_frames);
191+
}
177192
}
178193
}
179194

@@ -187,7 +202,6 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
187202
{
188203
this->_head1x1->process_(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
189204
}
190-
this->_head1x1->process(this->_z.topRows(bottleneck).leftCols(num_frames), num_frames);
191205
if (this->_head1x1_post_film)
192206
{
193207
Eigen::MatrixXf& head1x1_output = this->_head1x1->GetOutput();
@@ -205,9 +219,17 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
205219
this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
206220
}
207221

208-
// Store output to next layer (residual connection: input + _1x1 output)
209-
this->_output_next_layer.leftCols(num_frames).noalias() =
210-
input.leftCols(num_frames) + _1x1.GetOutput().leftCols(num_frames);
222+
// Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
223+
if (this->_layer1x1)
224+
{
225+
this->_output_next_layer.leftCols(num_frames).noalias() =
226+
input.leftCols(num_frames) + this->_layer1x1->GetOutput().leftCols(num_frames);
227+
}
228+
else
229+
{
230+
// If layer1x1 is inactive, residual connection is just the input (identity)
231+
this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames);
232+
}
211233
}
212234

213235
// LayerArray =================================================================
@@ -224,10 +246,10 @@ nam::wavenet::_LayerArray::_LayerArray(const LayerArrayParams& params)
224246
LayerParams layer_params(
225247
params.condition_size, params.channels, params.bottleneck, params.kernel_size, params.dilations[i],
226248
params.activation_configs[i], params.gating_modes[i], params.groups_input, params.groups_input_mixin,
227-
params.groups_1x1, params.head1x1_params, params.secondary_activation_configs[i], params.conv_pre_film_params,
228-
params.conv_post_film_params, params.input_mixin_pre_film_params, params.input_mixin_post_film_params,
229-
params.activation_pre_film_params, params.activation_post_film_params, params._1x1_post_film_params,
230-
params.head1x1_post_film_params);
249+
params.layer1x1_params, params.head1x1_params, params.secondary_activation_configs[i],
250+
params.conv_pre_film_params, params.conv_post_film_params, params.input_mixin_pre_film_params,
251+
params.input_mixin_post_film_params, params.activation_pre_film_params, params.activation_post_film_params,
252+
params._layer1x1_post_film_params, params.head1x1_post_film_params);
231253
this->_layers.push_back(_Layer(layer_params));
232254
}
233255
}
@@ -570,11 +592,21 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
570592

571593
const int groups = layer_config.value("groups_input", 1); // defaults to 1
572594
const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1
573-
const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1
574595

575596
const int channels = layer_config["channels"];
576597
const int bottleneck = layer_config.value("bottleneck", channels); // defaults to channels if not present
577598

599+
// Parse layer1x1 parameters
600+
bool layer1x1_active = true; // default to active if not present
601+
int layer1x1_groups = 1;
602+
if (layer_config.find("layer1x1") != layer_config.end())
603+
{
604+
const auto& layer1x1_config = layer_config["layer1x1"];
605+
layer1x1_active = layer1x1_config["active"]; // default to active
606+
layer1x1_groups = layer1x1_config["groups"];
607+
}
608+
nam::wavenet::Layer1x1Params layer1x1_params(layer1x1_active, layer1x1_groups);
609+
578610
const int input_size = layer_config["input_size"];
579611
const int condition_size = layer_config["condition_size"];
580612
const int head_size = layer_config["head_size"];
@@ -742,9 +774,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
742774
bool head1x1_active = false;
743775
int head1x1_out_channels = channels;
744776
int head1x1_groups = 1;
745-
if (layer_config.find("head_1x1") != layer_config.end())
777+
if (layer_config.find("head1x1") != layer_config.end())
746778
{
747-
const auto& head1x1_config = layer_config["head_1x1"];
779+
const auto& head1x1_config = layer_config["head1x1"];
748780
head1x1_active = head1x1_config["active"];
749781
head1x1_out_channels = head1x1_config["out_channels"];
750782
head1x1_groups = head1x1_config["groups"];
@@ -771,15 +803,22 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
771803
nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params("input_mixin_post_film");
772804
nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params("activation_pre_film");
773805
nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params("activation_post_film");
774-
nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params("1x1_post_film");
806+
nam::wavenet::_FiLMParams _layer1x1_post_film_params = parse_film_params("layer1x1_post_film");
775807
nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film");
776808

809+
// Validation: if layer1x1_post_film is active, layer1x1 must also be active
810+
if (_layer1x1_post_film_params.active && !layer1x1_active)
811+
{
812+
throw std::runtime_error("Layer array " + std::to_string(i)
813+
+ ": layer1x1_post_film cannot be active when layer1x1.active is false");
814+
}
815+
777816
layer_array_params.push_back(nam::wavenet::LayerArrayParams(
778817
input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations,
779-
std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, groups_1x1,
818+
std::move(activation_configs), std::move(gating_modes), head_bias, groups, groups_input_mixin, layer1x1_params,
780819
head1x1_params, std::move(secondary_activation_configs), conv_pre_film_params, conv_post_film_params,
781820
input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
782-
activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
821+
activation_post_film_params, _layer1x1_post_film_params, head1x1_post_film_params));
783822
}
784823
const bool with_head = !config["head"].is_null();
785824
const float head_scale = config["head_scale"];

0 commit comments

Comments
 (0)