Skip to content

Commit 44d54bb

Browse files
authored
[FEATURE] Ability to register new factories into get_dsp() (#156)
* Refactoring to get_dsp() factory with registry of implementations * LSTM and WaveNet instantiating from new factory registry * ConvNet * Linear
1 parent 8496fb9 commit 44d54bb

12 files changed

Lines changed: 178 additions & 51 deletions

File tree

NAM/convnet.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <unordered_set>
99

1010
#include "dsp.h"
11+
#include "registry.h"
1112
#include "convnet.h"
1213

1314
nam::convnet::BatchNorm::BatchNorm(const int dim, std::vector<float>::iterator& weights)
@@ -184,3 +185,20 @@ void nam::convnet::ConvNet::_rewind_buffers_()
184185
// Now we can do the rest of the rewind
185186
this->Buffer::_rewind_buffers_();
186187
}
188+
189+
// Factory
190+
std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, std::vector<float>& weights,
191+
const double expectedSampleRate)
192+
{
193+
const int channels = config["channels"];
194+
const std::vector<int> dilations = config["dilations"];
195+
const bool batchnorm = config["batchnorm"];
196+
const std::string activation = config["activation"];
197+
return std::make_unique<nam::convnet::ConvNet>(
198+
channels, dilations, batchnorm, activation, weights, expectedSampleRate);
199+
}
200+
201+
namespace
202+
{
203+
static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory);
204+
}

NAM/convnet.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,10 @@ class ConvNet : public Buffer
8686
int mPrewarmSamples = 0; // Pre-compute during initialization
8787
int PrewarmSamples() override { return mPrewarmSamples; };
8888
};
89+
90+
// Factory
91+
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
92+
const double expectedSampleRate);
93+
8994
}; // namespace convnet
9095
}; // namespace nam

NAM/dsp.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <unordered_set>
88

99
#include "dsp.h"
10+
#include "registry.h"
1011

1112
#define tanh_impl_ std::tanh
1213
// #define tanh_impl_ fast_tanh_
@@ -192,6 +193,15 @@ void nam::Linear::process(NAM_SAMPLE* input, NAM_SAMPLE* output, const int num_f
192193
nam::Buffer::_advance_input_buffer_(num_frames);
193194
}
194195

196+
// Factory
197+
std::unique_ptr<nam::DSP> nam::linear::Factory(const nlohmann::json& config, std::vector<float>& weights,
198+
const double expectedSampleRate)
199+
{
200+
const int receptive_field = config["receptive_field"];
201+
const bool bias = config["bias"];
202+
return std::make_unique<nam::Linear>(receptive_field, bias, weights, expectedSampleRate);
203+
}
204+
195205
// NN modules =================================================================
196206

197207
void nam::Conv1D::set_weights_(std::vector<float>::iterator& weights)

NAM/dsp.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ class Linear : public Buffer
164164
float _bias;
165165
};
166166

167+
namespace linear
168+
{
169+
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
170+
const double expectedSampleRate);
171+
} // namespace linear
172+
167173
// NN modules =================================================================
168174

169175
// TODO conv could take care of its own ring buffer.

NAM/get_dsp.cpp

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
#include <unordered_set>
55

66
#include "dsp.h"
7+
#include "registry.h"
78
#include "json.hpp"
89
#include "lstm.h"
910
#include "convnet.h"
1011
#include "wavenet.h"
12+
#include "get_dsp.h"
1113

1214
namespace nam
1315
{
@@ -102,12 +104,7 @@ std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspDat
102104
returnedConfig.config = j["config"];
103105
returnedConfig.metadata = j["metadata"];
104106
returnedConfig.weights = weights;
105-
if (j.find("sample_rate") != j.end())
106-
returnedConfig.expected_sample_rate = j["sample_rate"];
107-
else
108-
{
109-
returnedConfig.expected_sample_rate = -1.0;
110-
}
107+
returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(j);
111108

112109
/*Copy to a new dsp_config object for get_dsp below,
113110
since not sure if weights actually get modified as being non-const references on some
@@ -152,47 +149,9 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
152149
}
153150
const double expectedSampleRate = conf.expected_sample_rate;
154151

155-
std::unique_ptr<DSP> out = nullptr;
156-
if (architecture == "Linear")
157-
{
158-
const int receptive_field = config["receptive_field"];
159-
const bool _bias = config["bias"];
160-
out = std::make_unique<Linear>(receptive_field, _bias, weights, expectedSampleRate);
161-
}
162-
else if (architecture == "ConvNet")
163-
{
164-
const int channels = config["channels"];
165-
const bool batchnorm = config["batchnorm"];
166-
std::vector<int> dilations = config["dilations"];
167-
const std::string activation = config["activation"];
168-
out = std::make_unique<convnet::ConvNet>(channels, dilations, batchnorm, activation, weights, expectedSampleRate);
169-
}
170-
else if (architecture == "LSTM")
171-
{
172-
const int num_layers = config["num_layers"];
173-
const int input_size = config["input_size"];
174-
const int hidden_size = config["hidden_size"];
175-
out = std::make_unique<lstm::LSTM>(num_layers, input_size, hidden_size, weights, expectedSampleRate);
176-
}
177-
else if (architecture == "WaveNet")
178-
{
179-
std::vector<wavenet::LayerArrayParams> layer_array_params;
180-
for (size_t i = 0; i < config["layers"].size(); i++)
181-
{
182-
nlohmann::json layer_config = config["layers"][i];
183-
layer_array_params.push_back(
184-
wavenet::LayerArrayParams(layer_config["input_size"], layer_config["condition_size"], layer_config["head_size"],
185-
layer_config["channels"], layer_config["kernel_size"], layer_config["dilations"],
186-
layer_config["activation"], layer_config["gated"], layer_config["head_bias"]));
187-
}
188-
const bool with_head = !config["head"].is_null();
189-
const float head_scale = config["head_scale"];
190-
out = std::make_unique<wavenet::WaveNet>(layer_array_params, head_scale, with_head, weights, expectedSampleRate);
191-
}
192-
else
193-
{
194-
throw std::runtime_error("Unrecognized architecture");
195-
}
152+
// Initialize using registry-based factory
153+
std::unique_ptr<DSP> out =
154+
nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate);
196155
if (loudness.have)
197156
{
198157
out->SetLoudness(loudness.value);
@@ -212,4 +171,13 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
212171

213172
return out;
214173
}
174+
175+
double get_sample_rate_from_nam_file(const nlohmann::json& j)
176+
{
177+
if (j.find("sample_rate") != j.end())
178+
return j["sample_rate"];
179+
else
180+
return -1.0;
181+
}
182+
215183
}; // namespace nam

NAM/get_dsp.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#pragma once
2+
13
#include <fstream>
24

35
#include "dsp.h"
@@ -12,4 +14,8 @@ std::unique_ptr<DSP> get_dsp(dspData& conf);
1214

1315
// Get NAM from a provided .nam file path and store its configuration in the provided conf
1416
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig);
17+
18+
// Get sample rate from a .nam file
19+
// Returns -1 if not known (Really old .nam files)
20+
double get_sample_rate_from_nam_file(const nlohmann::json& j);
1521
}; // namespace nam

NAM/lstm.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#include <algorithm>
22
#include <string>
33
#include <vector>
4+
#include <memory>
45

6+
#include "registry.h"
57
#include "lstm.h"
68

79
nam::lstm::LSTMCell::LSTMCell(const int input_size, const int hidden_size, std::vector<float>::iterator& weights)
@@ -102,3 +104,19 @@ float nam::lstm::LSTM::_process_sample(const float x)
102104
this->_layers[i].process_(this->_layers[i - 1].get_hidden_state());
103105
return this->_head_weight.dot(this->_layers[this->_layers.size() - 1].get_hidden_state()) + this->_head_bias;
104106
}
107+
108+
// Factory to instantiate from nlohmann json
109+
std::unique_ptr<nam::DSP> nam::lstm::Factory(const nlohmann::json& config, std::vector<float>& weights,
110+
const double expectedSampleRate)
111+
{
112+
const int num_layers = config["num_layers"];
113+
const int input_size = config["input_size"];
114+
const int hidden_size = config["hidden_size"];
115+
return std::make_unique<nam::lstm::LSTM>(num_layers, input_size, hidden_size, weights, expectedSampleRate);
116+
}
117+
118+
// Register the factory
119+
namespace
120+
{
121+
static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory);
122+
}

NAM/lstm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <map>
55
#include <vector>
6+
#include <memory>
67

78
#include <Eigen/Dense>
89

@@ -69,5 +70,10 @@ class LSTM : public DSP
6970
// Since this is assumed to not be a parametric model, its shape should be (1,)
7071
Eigen::VectorXf _input;
7172
};
73+
74+
// Factory to instantiate from nlohmann json
75+
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
76+
const double expectedSampleRate);
77+
7278
}; // namespace lstm
7379
}; // namespace nam

NAM/registry.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
3+
// Registry for DSP objects
4+
5+
#include <string>
6+
#include <memory>
7+
#include <unordered_map>
8+
#include <functional>
9+
10+
#include "dsp.h"
11+
12+
namespace nam
13+
{
14+
namespace factory
15+
{
16+
// TODO get rid of weights and expectedSampleRate
17+
using FactoryFunction = std::function<std::unique_ptr<DSP>(const nlohmann::json&, std::vector<float>&, const double)>;
18+
19+
// Register factories for instantiating DSP objects
20+
class FactoryRegistry
21+
{
22+
public:
23+
static FactoryRegistry& instance()
24+
{
25+
static FactoryRegistry inst;
26+
return inst;
27+
}
28+
29+
void registerFactory(const std::string& key, FactoryFunction func)
30+
{
31+
// Assert that the key is not already registered
32+
if (factories_.find(key) != factories_.end())
33+
{
34+
throw std::runtime_error("Factory already registered for key: " + key);
35+
}
36+
factories_[key] = func;
37+
}
38+
39+
std::unique_ptr<DSP> create(const std::string& name, const nlohmann::json& config, std::vector<float>& weights,
40+
const double expectedSampleRate) const
41+
{
42+
auto it = factories_.find(name);
43+
if (it != factories_.end())
44+
{
45+
return it->second(config, weights, expectedSampleRate);
46+
}
47+
throw std::runtime_error("Factory not found for name: " + name);
48+
}
49+
50+
private:
51+
std::unordered_map<std::string, FactoryFunction> factories_;
52+
};
53+
54+
// Registration helper. Use this to register your factories.
55+
struct Helper
56+
{
57+
Helper(const std::string& name, FactoryFunction factory)
58+
{
59+
FactoryRegistry::instance().registerFactory(name, std::move(factory));
60+
}
61+
};
62+
} // namespace factory
63+
} // namespace nam

NAM/version.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
#ifndef version_h
2-
#define version_h
1+
#pragma once
32

43
// Make sure this matches NAM version in ../CMakeLists.txt!
54
#define NEURAL_AMP_MODELER_DSP_VERSION_MAJOR 0
65
#define NEURAL_AMP_MODELER_DSP_VERSION_MINOR 3
76
#define NEURAL_AMP_MODELER_DSP_VERSION_PATCH 0
8-
9-
#endif

0 commit comments

Comments
 (0)