44#include < unordered_set>
55
66#include " dsp.h"
7+ #include " registry.h"
78#include " json.hpp"
89#include " lstm.h"
910#include " convnet.h"
1011#include " wavenet.h"
12+ #include " get_dsp.h"
1113
1214namespace nam
1315{
@@ -102,12 +104,7 @@ std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspDat
102104 returnedConfig.config = j[" config" ];
103105 returnedConfig.metadata = j[" metadata" ];
104106 returnedConfig.weights = weights;
105- if (j.find (" sample_rate" ) != j.end ())
106- returnedConfig.expected_sample_rate = j[" sample_rate" ];
107- else
108- {
109- returnedConfig.expected_sample_rate = -1.0 ;
110- }
107+ returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file (j);
111108
112109 /* Copy to a new dsp_config object for get_dsp below,
113110 since not sure if weights actually get modified as being non-const references on some
@@ -152,47 +149,9 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
152149 }
153150 const double expectedSampleRate = conf.expected_sample_rate ;
154151
155- std::unique_ptr<DSP > out = nullptr ;
156- if (architecture == " Linear" )
157- {
158- const int receptive_field = config[" receptive_field" ];
159- const bool _bias = config[" bias" ];
160- out = std::make_unique<Linear>(receptive_field, _bias, weights, expectedSampleRate);
161- }
162- else if (architecture == " ConvNet" )
163- {
164- const int channels = config[" channels" ];
165- const bool batchnorm = config[" batchnorm" ];
166- std::vector<int > dilations = config[" dilations" ];
167- const std::string activation = config[" activation" ];
168- out = std::make_unique<convnet::ConvNet>(channels, dilations, batchnorm, activation, weights, expectedSampleRate);
169- }
170- else if (architecture == " LSTM" )
171- {
172- const int num_layers = config[" num_layers" ];
173- const int input_size = config[" input_size" ];
174- const int hidden_size = config[" hidden_size" ];
175- out = std::make_unique<lstm::LSTM >(num_layers, input_size, hidden_size, weights, expectedSampleRate);
176- }
177- else if (architecture == " WaveNet" )
178- {
179- std::vector<wavenet::LayerArrayParams> layer_array_params;
180- for (size_t i = 0 ; i < config[" layers" ].size (); i++)
181- {
182- nlohmann::json layer_config = config[" layers" ][i];
183- layer_array_params.push_back (
184- wavenet::LayerArrayParams (layer_config[" input_size" ], layer_config[" condition_size" ], layer_config[" head_size" ],
185- layer_config[" channels" ], layer_config[" kernel_size" ], layer_config[" dilations" ],
186- layer_config[" activation" ], layer_config[" gated" ], layer_config[" head_bias" ]));
187- }
188- const bool with_head = !config[" head" ].is_null ();
189- const float head_scale = config[" head_scale" ];
190- out = std::make_unique<wavenet::WaveNet>(layer_array_params, head_scale, with_head, weights, expectedSampleRate);
191- }
192- else
193- {
194- throw std::runtime_error (" Unrecognized architecture" );
195- }
152+ // Initialize using registry-based factory
153+ std::unique_ptr<DSP > out =
154+ nam::factory::FactoryRegistry::instance ().create (architecture, config, weights, expectedSampleRate);
196155 if (loudness.have )
197156 {
198157 out->SetLoudness (loudness.value );
@@ -212,4 +171,13 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
212171
213172 return out;
214173}
174+
175+ double get_sample_rate_from_nam_file (const nlohmann::json& j)
176+ {
177+ if (j.find (" sample_rate" ) != j.end ())
178+ return j[" sample_rate" ];
179+ else
180+ return -1.0 ;
181+ }
182+
215183}; // namespace nam
0 commit comments