Skip to content

Commit 0129678

Browse files
Copilotshauneccles
andcommitted
Phase 5 complete: Implement Resampler class with nanobind
Co-authored-by: shauneccles <21007065+shauneccles@users.noreply.github.com>
1 parent 5313545 commit 0129678

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

src/samplerate_nb.cpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,147 @@ nb::ndarray<nb::numpy, float> resample(
214214
}
215215
}
216216

217+
class Resampler {
218+
private:
219+
SRC_STATE *_state = nullptr;
220+
221+
public:
222+
int _converter_type = 0;
223+
int _channels = 0;
224+
225+
public:
226+
Resampler(const nb::object &converter_type, int channels)
227+
: _converter_type(get_converter_type(converter_type)),
228+
_channels(channels) {
229+
int _err_num = 0;
230+
_state = src_new(_converter_type, _channels, &_err_num);
231+
error_handler(_err_num);
232+
}
233+
234+
// copy constructor
235+
Resampler(const Resampler &r)
236+
: _converter_type(r._converter_type), _channels(r._channels) {
237+
int _err_num = 0;
238+
_state = src_clone(r._state, &_err_num);
239+
error_handler(_err_num);
240+
}
241+
242+
// move constructor
243+
Resampler(Resampler &&r)
244+
: _state(r._state),
245+
_converter_type(r._converter_type),
246+
_channels(r._channels) {
247+
r._state = nullptr;
248+
r._converter_type = 0;
249+
r._channels = 0;
250+
}
251+
252+
~Resampler() { src_delete(_state); } // src_delete handles nullptr case
253+
254+
nb::ndarray<nb::numpy, float> process(
255+
const nb::ndarray<nb::numpy, const float, nb::c_contig> &input,
256+
double sr_ratio, bool end_of_input) {
257+
// Get array dimensions
258+
size_t ndim = input.ndim();
259+
size_t num_frames = input.shape(0);
260+
261+
// set the number of channels
262+
int channels = 1;
263+
if (ndim == 2)
264+
channels = input.shape(1);
265+
else if (ndim > 2)
266+
throw std::domain_error("Input array should have at most 2 dimensions");
267+
268+
if (channels != _channels || channels == 0)
269+
throw std::domain_error("Invalid number of channels in input data.");
270+
271+
// Add a "fudge factor" to the size. This is because the actual number of
272+
// output samples generated on the last call when input is terminated can
273+
// be more than the expected number of output samples during mid-stream
274+
// steady-state processing. (Also, when the stream is started, the number
275+
// of output samples generated will generally be zero or otherwise less
276+
// than the number of samples in mid-stream processing.)
277+
const auto new_size =
278+
static_cast<size_t>(std::ceil(num_frames * sr_ratio))
279+
+ END_OF_INPUT_EXTRA_OUTPUT_FRAMES;
280+
281+
// Allocate output array
282+
size_t total_elements = new_size * channels;
283+
float* output_data = new float[total_elements];
284+
285+
// Create capsule for memory management
286+
nb::capsule owner(output_data, [](void* p) noexcept {
287+
delete[] static_cast<float*>(p);
288+
});
289+
290+
// libsamplerate struct
291+
SRC_DATA src_data = {
292+
const_cast<float *>(input.data()), // data_in
293+
output_data, // data_out
294+
static_cast<long>(num_frames), // input_frames
295+
long(new_size), // output_frames
296+
0, // input_frames_used, filled by libsamplerate
297+
0, // output_frames_gen, filled by libsamplerate
298+
end_of_input, // end_of_input
299+
sr_ratio // src_ratio, sampling rate conversion ratio
300+
};
301+
302+
// Release GIL for the entire resampling operation
303+
int err_code;
304+
long output_frames_gen;
305+
{
306+
nb::gil_scoped_release release;
307+
err_code = src_process(_state, &src_data);
308+
output_frames_gen = src_data.output_frames_gen;
309+
}
310+
error_handler(err_code);
311+
312+
// Handle unexpected output size
313+
if ((size_t)output_frames_gen > new_size) {
314+
// This means our fudge factor is too small.
315+
throw std::runtime_error("Generated more output samples than expected!");
316+
}
317+
318+
// Create output ndarray with proper shape and stride
319+
size_t output_shape[2];
320+
int64_t output_stride[2];
321+
322+
if (ndim == 2) {
323+
output_shape[0] = output_frames_gen;
324+
output_shape[1] = channels;
325+
output_stride[0] = channels * sizeof(float);
326+
output_stride[1] = sizeof(float);
327+
328+
return nb::ndarray<nb::numpy, float>(
329+
output_data,
330+
2,
331+
output_shape,
332+
owner,
333+
output_stride
334+
);
335+
} else {
336+
output_shape[0] = output_frames_gen;
337+
output_stride[0] = sizeof(float);
338+
339+
return nb::ndarray<nb::numpy, float>(
340+
output_data,
341+
1,
342+
output_shape,
343+
owner,
344+
output_stride
345+
);
346+
}
347+
}
348+
349+
void set_ratio(double new_ratio) {
350+
error_handler(src_set_ratio(_state, new_ratio));
351+
}
352+
353+
void reset() { error_handler(src_reset(_state)); }
354+
355+
Resampler clone() const { return Resampler(*this); }
356+
};
357+
217358
} // namespace samplerate
218359

219360
namespace sr = samplerate;
@@ -296,8 +437,58 @@ NB_MODULE(samplerate, m) {
296437
"input"_a, "ratio"_a, "converter_type"_a = "sinc_best",
297438
"verbose"_a = false);
298439

440+
nb::class_<sr::Resampler>(m_converters, "Resampler", R"mydelimiter(
441+
Resampler.
442+
443+
Parameters
444+
----------
445+
converter_type : ConverterType, str, or int
446+
Sample rate converter (default: `sinc_best`).
447+
num_channels : int
448+
Number of channels.
449+
)mydelimiter")
450+
.def(nb::init<const nb::object &, int>(),
451+
"converter_type"_a = "sinc_best", "channels"_a = 1)
452+
.def(nb::init<sr::Resampler>())
453+
.def("process", &sr::Resampler::process, R"mydelimiter(
454+
Resample the signal in `input_data`.
455+
456+
Parameters
457+
----------
458+
input_data : ndarray
459+
Input data.
460+
Input data with one or more channels is represented as a 2D array of shape
461+
(`num_frames`, `num_channels`).
462+
A single channel can be provided as a 1D array of `num_frames` length.
463+
For use with `libsamplerate`, `input_data` is converted to 32-bit float and
464+
C (row-major) memory order.
465+
ratio : float
466+
Conversion ratio = output sample rate / input sample rate.
467+
end_of_input : int
468+
Set to `True` if no more data is available, or to `False` otherwise.
469+
verbose : bool
470+
If `True`, print additional information about the conversion.
471+
472+
Returns
473+
-------
474+
output_data : ndarray
475+
Resampled input data.
476+
)mydelimiter",
477+
"input"_a, "ratio"_a, "end_of_input"_a = false)
478+
.def("reset", &sr::Resampler::reset, "Reset internal state.")
479+
.def("set_ratio", &sr::Resampler::set_ratio,
480+
"Set a new conversion ratio immediately.")
481+
.def("clone", &sr::Resampler::clone,
482+
"Creates a copy of the resampler object with the same internal "
483+
"state.")
484+
.def_ro("converter_type", &sr::Resampler::_converter_type,
485+
"Converter type.")
486+
.def_ro("channels", &sr::Resampler::_channels,
487+
"Number of channels.");
488+
299489
// Convenience imports
300490
m.attr("ResamplingError") = m_exceptions.attr("ResamplingError");
301491
m.attr("resample") = m_converters.attr("resample");
492+
m.attr("Resampler") = m_converters.attr("Resampler");
302493
m.attr("ConverterType") = m_converters.attr("ConverterType");
303494
}

0 commit comments

Comments
 (0)