Commit be317cf

Fix bugs, add an end-to-end test with a model with all new features (#198)
* Assert no post-head 1x1 FiLM if there's no head 1x1

* Add groups_input_mixin parameter to wavenet Layer

  Adds groups_input_mixin parameter to control grouped convolutions in the input_mixin Conv1x1 layer. The parameter is propagated through Layer, LayerArrayParams, and LayerArray constructors. Factory parsing defaults to 1 if not specified in the model JSON for backward compatibility.

  Also fixes a bug in test_real_time_safe where make_layer_all_films was incorrectly activating head1x1_post_film when head1x1 was inactive.

* Change JSON key from 'groups' to 'groups_input' in WaveNet factory

  Aligns the JSON configuration key with the LayerArrayParams attribute name for consistency. The factory now reads 'groups_input' instead of 'groups' from the layer configuration.

* Consolidate gating_activation_post_film_params with activation_post_film_params

  Removed the separate gating_activation_post_film_params parameter and now use activation_post_film_params for both gated and blended modes. This simplifies the API and reduces redundancy since both modes apply FiLM modulation after activation in the same way.

  Changes:
  - Removed gating_activation_post_film_params parameter from _Layer, _LayerArray, and LayerArrayParams constructors
  - Removed _gating_activation_post_film member variable from _Layer
  - Updated _Layer::Process() to use _activation_post_film for gated mode
  - Updated all test files to use 7 FiLM parameters instead of 8
  - Updated weight count in test_real_time_safe.cpp accordingly

* Refactor secondary_activation to use ActivationConfig

  Update WaveNet C++ code to handle secondary_activation as ActivationConfig instead of string for proper type safety. This enables support for complex activation types with parameters (e.g., PReLU, LeakyHardtanh).

  Changes:
  - Modify _Layer, _LayerArray, and LayerArrayParams to use typed ActivationConfig
  - Update Factory function to parse secondary_activation from JSON as ActivationConfig
  - Update all test files to use ActivationConfig for secondary activation parameters

  All tests pass successfully.

* Fix Conv1D and Conv1x1 to use groups parameters

  Fixed two bugs in _Layer constructor:
  - Conv1D was missing groups_input parameter (always defaulted to 1)
  - Conv1x1 _1x1 was passing groups_1x1 as bias parameter instead of groups

  These fixes enable proper grouped convolutions for reduced computation.

* Add wavenet_a2_max.nam

* Add wavenet_a2_max.nam to end-to-end tests, formatting

* Add real-time safety tests for FiLM.

* Add test for RT safety for Layer with gated activation and post-activation FiLM. Failing.

* Fix test_layer_post_activation_film_gated_realtime_safe test errors

  - Fix incorrect parameter comments (lines 713-720): corrected parameter names to match actual Layer constructor
  - Fix misleading comment on activation_post_film weight calculation: clarify that FiLM is created with bottleneck as input_dim, shift doubles output channels
  - Remove 4 extra placeholder weights that were causing assertion failures
  - Apply same fixes to test_layer_post_activation_film_blended_realtime_safe

* Fix real-time safety: eliminate allocations in gated/blended activation paths

  - Use Eigen::Ref in FiLM::Process and Conv1x1::process_ to accept block expressions without creating temporary matrices
  - Add pre-allocated buffers in GatingActivation and BlendingActivation to avoid allocating MatrixXf objects in processing loops

* Fix LayerArray head buffer size mismatch when head_1x1 is active

  When head_1x1 was active with out_channels != bottleneck, _head_inputs and _head_rechannel were incorrectly sized using bottleneck instead of head_1x1.out_channels, causing an Eigen matrix dimension mismatch.

  Added _head_output_size member to _LayerArray that correctly computes the head output size (head_1x1.out_channels if active, else bottleneck). Updated weight generator to match.

* Remove unused private variable

1 parent e462709 commit be317cf

21 files changed

Lines changed: 4396 additions & 429 deletions

NAM/dsp.cpp

Lines changed: 1 addition & 1 deletion
@@ -417,7 +417,7 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu
   return result;
 }

-void nam::Conv1x1::process_(const Eigen::MatrixXf& input, const int num_frames)
+void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames)
 {
   assert(num_frames <= _output.cols());

NAM/dsp.h

Lines changed: 2 additions & 1 deletion
@@ -207,7 +207,8 @@ class Conv1x1
   Eigen::MatrixXf process(const Eigen::MatrixXf& input) const { return process(input, (int)input.cols()); };
   Eigen::MatrixXf process(const Eigen::MatrixXf& input, const int num_frames) const;
   // Store output to pre-allocated _output; access with GetOutput()
-  void process_(const Eigen::MatrixXf& input, const int num_frames);
+  // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
+  void process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames);

   long get_out_channels() const { return this->_weight.rows(); };
   long get_in_channels() const { return this->_weight.cols(); };

NAM/film.h

Lines changed: 6 additions & 2 deletions
@@ -47,7 +47,9 @@ class FiLM
   // :param input: (input_dim x num_frames)
   // :param condition: (condition_dim x num_frames)
   // Writes (input_dim x num_frames) into internal output buffer; access via GetOutput().
-  void Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
+  // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
+  void Process(const Eigen::Ref<const Eigen::MatrixXf>& input, const Eigen::Ref<const Eigen::MatrixXf>& condition,
+               const int num_frames)
   {
     assert(get_input_dim() == input.rows());
     assert(get_condition_dim() == condition.rows());
@@ -72,7 +74,9 @@ class FiLM
   }

   // in-place
-  void Process_(Eigen::MatrixXf& input, const Eigen::MatrixXf& condition, const int num_frames)
+  // Uses Eigen::Ref to accept matrices and block expressions without creating temporaries (real-time safe)
+  void Process_(Eigen::Ref<Eigen::MatrixXf> input, const Eigen::Ref<const Eigen::MatrixXf>& condition,
+                const int num_frames)
   {
     Process(input, condition, num_frames);
     input.leftCols(num_frames).noalias() = _output.leftCols(num_frames);
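
The signature changes above replace `const Eigen::MatrixXf&` parameters with `Eigen::Ref`, so a block such as `_z.topRows(bottleneck)` can be passed without first being copied into a temporary matrix. The idea can be sketched without Eigen. In the block below, `ColumnView`, `sum_owned`, `sum_view`, and `head_of` are hypothetical names invented for illustration (none of them are part of NAM or Eigen), and a plain pointer-plus-length struct stands in for `Eigen::Ref`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical non-owning view over contiguous floats (a stand-in for Eigen::Ref).
struct ColumnView
{
  const float* data;
  std::size_t size;
};

// Owning-parameter version: a caller that only holds a slice of a larger
// buffer has to build a new vector (an allocation) before it can call this.
float sum_owned(const std::vector<float>& values)
{
  float total = 0.0f;
  for (float v : values)
    total += v;
  return total;
}

// View version: reads the caller's storage directly, so no copy is required.
float sum_view(ColumnView view)
{
  float total = 0.0f;
  for (std::size_t i = 0; i < view.size; i++)
    total += view.data[i];
  return total;
}

// View the first `count` elements of `storage` without copying them.
ColumnView head_of(const std::vector<float>& storage, std::size_t count)
{
  assert(count <= storage.size());
  return ColumnView{storage.data(), count};
}
```

Calling `sum_view(head_of(storage, 2))` touches only the existing storage, whereas `sum_owned` needs a freshly constructed vector for the same slice. That extra construction is the kind of hidden allocation the commit removes from the audio path.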

NAM/gating_activations.h

Lines changed: 23 additions & 19 deletions
@@ -42,9 +42,10 @@ class GatingActivation
     {
       throw std::invalid_argument("GatingActivation: number of input channels must be positive");
     }
-    // Initialize input buffer with correct size
+    // Initialize buffers with correct size
     // Note: current code copies column-by-column so we only need (num_channels, 1)
     input_buffer.resize(num_channels, 1);
+    gating_buffer.resize(num_channels, 1);
   }

   ~GatingActivation() = default;
@@ -64,23 +65,20 @@ class GatingActivation
     assert(output.cols() == input.cols());

     // Process column-by-column to ensure memory contiguity (important for column-major matrices)
+    // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
     const int num_samples = input.cols();
     for (int i = 0; i < num_samples; i++)
     {
-      // Store pre-activation input values in buffer to avoid overwriting issues
+      // Copy to pre-allocated buffers and apply activations in-place
       input_buffer = input.block(0, i, num_channels, 1);
+      input_activation->apply(input_buffer);

-      // Apply activation to input channels
-      Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
-      input_activation->apply(input_block);
-
-      // Apply activation to gating channels
-      Eigen::MatrixXf gating_block = input.block(num_channels, i, num_channels, 1);
-      gating_activation->apply(gating_block);
+      gating_buffer = input.block(num_channels, i, num_channels, 1);
+      gating_activation->apply(gating_buffer);

       // Element-wise multiplication and store result
       // For wavenet compatibility, we assume one-to-one mapping
-      output.block(0, i, num_channels, 1) = input_block.array() * gating_block.array();
+      output.block(0, i, num_channels, 1) = input_buffer.array() * gating_buffer.array();
     }
   }

@@ -99,6 +97,7 @@ class GatingActivation
   activations::Activation::Ptr gating_activation;
   int num_channels;
   Eigen::MatrixXf input_buffer;
+  Eigen::MatrixXf gating_buffer;
 };

 class BlendingActivation
@@ -118,9 +117,11 @@ class BlendingActivation
   {
     assert(num_channels > 0);

-    // Initialize input buffer with correct size
+    // Initialize buffers with correct size
     // Note: current code copies column-by-column so we only need (num_channels, 1)
+    pre_activation_buffer.resize(num_channels, 1);
     input_buffer.resize(num_channels, 1);
+    blend_buffer.resize(num_channels, 1);
   }

   ~BlendingActivation() = default;
@@ -140,23 +141,24 @@ class BlendingActivation
     assert(output.cols() == input.cols());

     // Process column-by-column to ensure memory contiguity
+    // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
     const int num_samples = input.cols();
     for (int i = 0; i < num_samples; i++)
     {
       // Store pre-activation input values in buffer
-      input_buffer = input.block(0, i, num_channels, 1);
+      pre_activation_buffer = input.block(0, i, num_channels, 1);

-      // Apply activation to input channels
-      Eigen::MatrixXf input_block = input.block(0, i, num_channels, 1);
-      input_activation->apply(input_block);
+      // Copy to pre-allocated buffer and apply activation to input channels
+      input_buffer = input.block(0, i, num_channels, 1);
+      input_activation->apply(input_buffer);

-      // Apply activation to blend channels to compute alpha
-      Eigen::MatrixXf blend_block = input.block(num_channels, i, num_channels, 1);
-      blending_activation->apply(blend_block);
+      // Copy to pre-allocated buffer and apply activation to blend channels to compute alpha
+      blend_buffer = input.block(num_channels, i, num_channels, 1);
+      blending_activation->apply(blend_buffer);

       // Weighted blending: alpha * activated_input + (1 - alpha) * pre_activation_input
       output.block(0, i, num_channels, 1) =
-        blend_block.array() * input_block.array() + (1.0f - blend_block.array()) * input_buffer.array();
+        blend_buffer.array() * input_buffer.array() + (1.0f - blend_buffer.array()) * pre_activation_buffer.array();
     }
   }

@@ -174,7 +176,9 @@ class BlendingActivation
   activations::Activation::Ptr input_activation;
   activations::Activation::Ptr blending_activation;
   int num_channels;
+  Eigen::MatrixXf pre_activation_buffer;
   Eigen::MatrixXf input_buffer;
+  Eigen::MatrixXf blend_buffer;
 };

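
The two loops in this file follow the same pattern: size the scratch buffers once in the constructor, then only overwrite them per column. A minimal sketch of that pattern follows, using `std::vector` in place of Eigen matrices. `ColumnGate` is a hypothetical class written for this note, and `tanh` and sigmoid are assumed activations chosen for illustration; in the actual code both activations are configurable.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the gated/blended column update; not NAM's classes.
class ColumnGate
{
public:
  // All scratch buffers are sized here, once, outside the processing path.
  explicit ColumnGate(std::size_t num_channels)
  : _num_channels(num_channels)
  , _pre(num_channels)
  , _act(num_channels)
  , _gate(num_channels)
  {
    assert(num_channels > 0);
  }

  // `column` holds 2 * num_channels values: input channels, then gate channels.
  // Gated: out = tanh(x) * sigmoid(g)
  void gated(const std::vector<float>& column, std::vector<float>& out)
  {
    fill_buffers(column, out);
    for (std::size_t i = 0; i < _num_channels; i++)
      out[i] = _act[i] * _gate[i];
  }

  // Blended: out = alpha * tanh(x) + (1 - alpha) * x, with alpha = sigmoid(g)
  void blended(const std::vector<float>& column, std::vector<float>& out)
  {
    fill_buffers(column, out);
    for (std::size_t i = 0; i < _num_channels; i++)
      out[i] = _gate[i] * _act[i] + (1.0f - _gate[i]) * _pre[i];
  }

private:
  // Overwrites the pre-sized buffers in place; nothing is resized or allocated.
  void fill_buffers(const std::vector<float>& column, const std::vector<float>& out)
  {
    assert(column.size() == 2 * _num_channels);
    assert(out.size() == _num_channels);
    for (std::size_t i = 0; i < _num_channels; i++)
    {
      _pre[i] = column[i];
      _act[i] = std::tanh(column[i]);
      _gate[i] = 1.0f / (1.0f + std::exp(-column[_num_channels + i]));
    }
  }

  std::size_t _num_channels;
  std::vector<float> _pre;  // untouched input, needed only for blending
  std::vector<float> _act;  // activated input
  std::vector<float> _gate; // activated gate / blend weight
};
```

The blended path needs the third buffer because the activated input and the untouched input are both used in the final expression, which matches the reason `pre_activation_buffer` is added alongside `input_buffer` in the diff.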

NAM/wavenet.cpp

Lines changed: 44 additions & 44 deletions
@@ -47,8 +47,6 @@ void nam::wavenet::_Layer::SetMaxBufferSize(const int maxBufferSize)
     this->_activation_pre_film->SetMaxBufferSize(maxBufferSize);
   if (this->_activation_post_film)
     this->_activation_post_film->SetMaxBufferSize(maxBufferSize);
-  if (this->_gating_activation_post_film)
-    this->_gating_activation_post_film->SetMaxBufferSize(maxBufferSize);
   if (this->_1x1_post_film)
     this->_1x1_post_film->SetMaxBufferSize(maxBufferSize);
   if (this->_head1x1_post_film)
@@ -77,8 +75,6 @@ void nam::wavenet::_Layer::set_weights_(std::vector<float>::iterator& weights)
     this->_activation_pre_film->set_weights_(weights);
   if (this->_activation_post_film)
     this->_activation_post_film->set_weights_(weights);
-  if (this->_gating_activation_post_film)
-    this->_gating_activation_post_film->set_weights_(weights);
   if (this->_1x1_post_film)
     this->_1x1_post_film->set_weights_(weights);
   if (this->_head1x1_post_film)
@@ -150,12 +146,12 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
     auto input_block = this->_z.leftCols(num_frames);
     auto output_block = this->_z.topRows(bottleneck).leftCols(num_frames);
     this->_gating_activation->apply(input_block, output_block);
-    if (this->_gating_activation_post_film)
+    if (this->_activation_post_film)
     {
       // Use Process() for blocks and copy result back
-      this->_gating_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames);
+      this->_activation_post_film->Process(this->_z.topRows(bottleneck), condition, num_frames);
       this->_z.topRows(bottleneck).leftCols(num_frames).noalias() =
-        this->_gating_activation_post_film->GetOutput().leftCols(num_frames);
+        this->_activation_post_film->GetOutput().leftCols(num_frames);
     }
     _1x1.process_(this->_z.topRows(bottleneck), num_frames);
   }
@@ -219,23 +215,23 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma
 nam::wavenet::_LayerArray::_LayerArray(
   const int input_size, const int condition_size, const int head_size, const int channels, const int bottleneck,
   const int kernel_size, const std::vector<int>& dilations, const activations::ActivationConfig& activation_config,
-  const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_1x1,
-  const Head1x1Params& head1x1_params, const std::string& secondary_activation, const _FiLMParams& conv_pre_film_params,
+  const GatingMode gating_mode, const bool head_bias, const int groups_input, const int groups_input_mixin,
+  const int groups_1x1, const Head1x1Params& head1x1_params,
+  const activations::ActivationConfig& secondary_activation_config, const _FiLMParams& conv_pre_film_params,
   const _FiLMParams& conv_post_film_params, const _FiLMParams& input_mixin_pre_film_params,
   const _FiLMParams& input_mixin_post_film_params, const _FiLMParams& activation_pre_film_params,
-  const _FiLMParams& activation_post_film_params, const _FiLMParams& gating_activation_post_film_params,
-  const _FiLMParams& _1x1_post_film_params, const _FiLMParams& head1x1_post_film_params)
+  const _FiLMParams& activation_post_film_params, const _FiLMParams& _1x1_post_film_params,
+  const _FiLMParams& head1x1_post_film_params)
 : _rechannel(input_size, channels, false)
-, _head_rechannel(bottleneck, head_size, head_bias)
-, _bottleneck(bottleneck)
+, _head_rechannel(head1x1_params.active ? head1x1_params.out_channels : bottleneck, head_size, head_bias)
+, _head_output_size(head1x1_params.active ? head1x1_params.out_channels : bottleneck)
 {
   for (size_t i = 0; i < dilations.size(); i++)
-    this->_layers.push_back(_Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config,
-                                   gating_mode, groups_input, groups_1x1, head1x1_params, secondary_activation,
-                                   conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params,
-                                   input_mixin_post_film_params, activation_pre_film_params,
-                                   activation_post_film_params, gating_activation_post_film_params,
-                                   _1x1_post_film_params, head1x1_post_film_params));
+    this->_layers.push_back(
+      _Layer(condition_size, channels, bottleneck, kernel_size, dilations[i], activation_config, gating_mode,
+             groups_input, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config,
+             conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params,
+             activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
 }

 void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
@@ -249,7 +245,8 @@ void nam::wavenet::_LayerArray::SetMaxBufferSize(const int maxBufferSize)
   // Pre-allocate output buffers
   const long channels = this->_get_channels();
   this->_layer_outputs.resize(channels, maxBufferSize);
-  this->_head_inputs.resize(this->_bottleneck, maxBufferSize);
+  // _head_inputs size matches actual head output: head1x1.out_channels if active, else bottleneck
+  this->_head_inputs.resize(this->_head_output_size, maxBufferSize);
 }

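
The two hunks above size both `_head_rechannel` and `_head_inputs` from the same expression. The rule can be isolated as a small function. This is only a sketch: `Head1x1Sketch` and `head_output_size` are hypothetical names that mirror the two `Head1x1Params` fields used in the diff, not code from the repository.

```cpp
#include <cassert>

// Hypothetical mirror of the Head1x1Params fields referenced in the diff.
struct Head1x1Sketch
{
  bool active;
  int out_channels;
};

// Row count for the head input buffer and the head rechannel's input side:
// the head 1x1's output width when it is active, otherwise the bottleneck width.
int head_output_size(const Head1x1Sketch& head1x1, int bottleneck)
{
  assert(bottleneck > 0);
  return head1x1.active ? head1x1.out_channels : bottleneck;
}
```

With `bottleneck = 4` and an active head 1x1 of `out_channels = 8`, the buffers need 8 rows. Sizing them from `bottleneck` in that case is the dimension mismatch described in the commit message.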

@@ -386,12 +383,12 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
       layer_array_params[i].input_size, layer_array_params[i].condition_size, layer_array_params[i].head_size,
       layer_array_params[i].channels, layer_array_params[i].bottleneck, layer_array_params[i].kernel_size,
       layer_array_params[i].dilations, layer_array_params[i].activation_config, layer_array_params[i].gating_mode,
-      layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_1x1,
-      layer_array_params[i].head1x1_params, layer_array_params[i].secondary_activation,
-      layer_array_params[i].conv_pre_film_params, layer_array_params[i].conv_post_film_params,
-      layer_array_params[i].input_mixin_pre_film_params, layer_array_params[i].input_mixin_post_film_params,
-      layer_array_params[i].activation_pre_film_params, layer_array_params[i].activation_post_film_params,
-      layer_array_params[i].gating_activation_post_film_params, layer_array_params[i]._1x1_post_film_params,
+      layer_array_params[i].head_bias, layer_array_params[i].groups_input, layer_array_params[i].groups_input_mixin,
+      layer_array_params[i].groups_1x1, layer_array_params[i].head1x1_params,
+      layer_array_params[i].secondary_activation_config, layer_array_params[i].conv_pre_film_params,
+      layer_array_params[i].conv_post_film_params, layer_array_params[i].input_mixin_pre_film_params,
+      layer_array_params[i].input_mixin_post_film_params, layer_array_params[i].activation_pre_film_params,
+      layer_array_params[i].activation_post_film_params, layer_array_params[i]._1x1_post_film_params,
       layer_array_params[i].head1x1_post_film_params));
     if (i > 0)
       if (layer_array_params[i].channels != layer_array_params[i - 1].head_size)
@@ -583,7 +580,8 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
   {
     nlohmann::json layer_config = config["layers"][i];

-    const int groups = layer_config.value("groups", 1); // defaults to 1
+    const int groups = layer_config.value("groups_input", 1); // defaults to 1
+    const int groups_input_mixin = layer_config.value("groups_input_mixin", 1); // defaults to 1
     const int groups_1x1 = layer_config.value("groups_1x1", 1); // defaults to 1

     const int channels = layer_config["channels"];
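
The three `value(key, 1)` lookups above fall back to 1 whenever a key is absent. One consequence, as far as this hunk shows, is that a layer config still carrying the old `groups` key would now be read as `groups_input = 1`. The sketch below reproduces that lookup with a `std::map` standing in for the JSON object. `LayerConfigSketch`, `value_or`, `GroupsSketch`, and `parse_groups` are hypothetical names for illustration only.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for the per-layer JSON object (integer fields only).
typedef std::map<std::string, int> LayerConfigSketch;

// Mimics a "value with default" lookup: the stored value if the key
// exists, otherwise the fallback.
int value_or(const LayerConfigSketch& config, const std::string& key, int fallback)
{
  LayerConfigSketch::const_iterator it = config.find(key);
  return it == config.end() ? fallback : it->second;
}

struct GroupsSketch
{
  int groups_input;
  int groups_input_mixin;
  int groups_1x1;
};

// Reads the three group counts using the key names from the new factory code.
GroupsSketch parse_groups(const LayerConfigSketch& config)
{
  GroupsSketch groups;
  groups.groups_input = value_or(config, "groups_input", 1);
  groups.groups_input_mixin = value_or(config, "groups_input_mixin", 1);
  groups.groups_1x1 = value_or(config, "groups_1x1", 1);
  return groups;
}
```

A config containing only `groups = 4` yields `groups_input == 1` here, because the old key is no longer consulted.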
@@ -599,25 +597,25 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
       activations::ActivationConfig::from_json(layer_config["activation"]);
     // Parse gating mode - support both old "gated" boolean and new "gating_mode" string
     GatingMode gating_mode = GatingMode::NONE;
-    std::string secondary_activation;
+    activations::ActivationConfig secondary_activation_config;

     if (layer_config.find("gating_mode") != layer_config.end())
     {
       std::string gating_mode_str = layer_config["gating_mode"].get<std::string>();
       if (gating_mode_str == "gated")
       {
         gating_mode = GatingMode::GATED;
-        secondary_activation = layer_config["secondary_activation"].get<std::string>();
+        secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]);
       }
       else if (gating_mode_str == "blended")
       {
         gating_mode = GatingMode::BLENDED;
-        secondary_activation = layer_config["secondary_activation"].get<std::string>();
+        secondary_activation_config = activations::ActivationConfig::from_json(layer_config["secondary_activation"]);
       }
       else if (gating_mode_str == "none")
       {
         gating_mode = GatingMode::NONE;
-        secondary_activation.clear();
+        // Leave secondary_activation_config with empty type
       }
       else
         throw std::runtime_error("Invalid gating_mode: " + gating_mode_str);
@@ -629,12 +627,9 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
       gating_mode = gated ? GatingMode::GATED : GatingMode::NONE;
       if (gated)
       {
-        secondary_activation = "Sigmoid";
-      }
-      else
-      {
-        secondary_activation.clear();
+        secondary_activation_config = activations::ActivationConfig::simple(activations::ActivationType::Sigmoid);
       }
+      // else: leave secondary_activation_config uninitialized
     }
     else
     {
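
The branching in the two hunks above accepts either the newer `gating_mode` string or the legacy `gated` boolean. The decision table is sketched below under several simplifications, all of which are assumptions made for illustration: the secondary activation is represented by a plain name string rather than an `ActivationConfig`, absent keys are modelled as null pointers, and the behaviour when neither key is present is guessed to be "no gating" because that branch is cut off in the diff. `GatingModeSketch`, `GatingChoice`, and `resolve_gating` are hypothetical names.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

enum class GatingModeSketch { NONE, GATED, BLENDED };

struct GatingChoice
{
  GatingModeSketch mode;
  std::string secondary_activation; // empty when no secondary activation applies
};

// `gating_mode` and `gated` are nullptr when the key is absent from the config.
// `secondary` is the configured secondary activation name, used only for the
// "gated" and "blended" string modes.
GatingChoice resolve_gating(const std::string* gating_mode, const bool* gated, const std::string& secondary)
{
  if (gating_mode != nullptr)
  {
    if (*gating_mode == "gated")
      return GatingChoice{GatingModeSketch::GATED, secondary};
    if (*gating_mode == "blended")
      return GatingChoice{GatingModeSketch::BLENDED, secondary};
    if (*gating_mode == "none")
      return GatingChoice{GatingModeSketch::NONE, ""};
    throw std::runtime_error("Invalid gating_mode: " + *gating_mode);
  }
  // Legacy boolean: gated == true maps to GATED with a Sigmoid secondary activation.
  if (gated != nullptr && *gated)
    return GatingChoice{GatingModeSketch::GATED, "Sigmoid"};
  // gated == false maps to NONE (from the diff); treating "neither key
  // present" the same way is an assumption of this sketch.
  return GatingChoice{GatingModeSketch::NONE, ""};
}
```

The string key takes precedence over the boolean, which matches the order of the `if`/`else if` chain in the factory.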
@@ -644,9 +639,16 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     const bool head_bias = layer_config["head_bias"];

     // Parse head1x1 parameters
-    bool head1x1_active = layer_config.value("head1x1_active", false);
-    int head1x1_out_channels = layer_config.value("head1x1_out_channels", channels);
-    int head1x1_groups = layer_config.value("head1x1_groups", 1);
+    bool head1x1_active = false;
+    int head1x1_out_channels = channels;
+    int head1x1_groups = 1;
+    if (layer_config.find("head_1x1") != layer_config.end())
+    {
+      const auto& head1x1_config = layer_config["head_1x1"];
+      head1x1_active = head1x1_config["active"];
+      head1x1_out_channels = head1x1_config["out_channels"];
+      head1x1_groups = head1x1_config["groups"];
+    }
     nam::wavenet::Head1x1Params head1x1_params(head1x1_active, head1x1_out_channels, head1x1_groups);

     // Helper function to parse FiLM parameters
@@ -668,16 +670,14 @@ std::unique_ptr<nam::DSP> nam::wavenet::Factory(const nlohmann::json& config, st
     nam::wavenet::_FiLMParams input_mixin_post_film_params = parse_film_params("input_mixin_post_film");
     nam::wavenet::_FiLMParams activation_pre_film_params = parse_film_params("activation_pre_film");
     nam::wavenet::_FiLMParams activation_post_film_params = parse_film_params("activation_post_film");
-    nam::wavenet::_FiLMParams gating_activation_post_film_params = parse_film_params("gating_activation_post_film");
     nam::wavenet::_FiLMParams _1x1_post_film_params = parse_film_params("1x1_post_film");
     nam::wavenet::_FiLMParams head1x1_post_film_params = parse_film_params("head1x1_post_film");

     layer_array_params.push_back(nam::wavenet::LayerArrayParams(
       input_size, condition_size, head_size, channels, bottleneck, kernel_size, dilations, activation_config,
-      gating_mode, head_bias, groups, groups_1x1, head1x1_params, secondary_activation, conv_pre_film_params,
-      conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params, activation_pre_film_params,
-      activation_post_film_params, gating_activation_post_film_params, _1x1_post_film_params,
-      head1x1_post_film_params));
+      gating_mode, head_bias, groups, groups_input_mixin, groups_1x1, head1x1_params, secondary_activation_config,
+      conv_pre_film_params, conv_post_film_params, input_mixin_pre_film_params, input_mixin_post_film_params,
+      activation_pre_film_params, activation_post_film_params, _1x1_post_film_params, head1x1_post_film_params));
   }
   const bool with_head = !config["head"].is_null();
   const float head_scale = config["head_scale"];
