Skip to content

Commit 0ce826c

Browse files
Copilotshauneccles
andcommitted
Phase 6 complete: Implement CallbackResampler with nanobind
Co-authored-by: shauneccles <21007065+shauneccles@users.noreply.github.com>
1 parent 0129678 commit 0ce826c

1 file changed

Lines changed: 286 additions & 0 deletions

File tree

src/samplerate_nb.cpp

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,235 @@ class Resampler {
355355
Resampler clone() const { return Resampler(*this); }
356356
};
357357

358+
namespace {
359+
360+
long the_callback_func(void *cb_data, float **data);
361+
362+
} // namespace
363+
364+
class CallbackResampler {
365+
private:
366+
SRC_STATE *_state = nullptr;
367+
callback_t _callback = nullptr;
368+
nb_array_f32 _current_buffer;
369+
size_t _buffer_ndim = 0;
370+
std::string _callback_error_msg = "";
371+
372+
public:
373+
double _ratio = 0.0;
374+
int _converter_type = 0;
375+
size_t _channels = 0;
376+
377+
private:
378+
void _create() {
379+
int _err_num = 0;
380+
_state = src_callback_new(the_callback_func, _converter_type, (int)_channels,
381+
&_err_num, static_cast<void *>(this));
382+
if (_state == nullptr) error_handler(_err_num);
383+
}
384+
385+
void _destroy() {
386+
if (_state != nullptr) {
387+
src_delete(_state);
388+
_state = nullptr;
389+
}
390+
}
391+
392+
public:
393+
CallbackResampler(const callback_t &callback_func, double ratio,
394+
const nb::object &converter_type, size_t channels)
395+
: _callback(callback_func),
396+
_ratio(ratio),
397+
_converter_type(get_converter_type(converter_type)),
398+
_channels(channels) {
399+
_create();
400+
}
401+
402+
// copy constructor
403+
CallbackResampler(const CallbackResampler &r)
404+
: _callback(r._callback),
405+
_ratio(r._ratio),
406+
_converter_type(r._converter_type),
407+
_channels(r._channels) {
408+
int _err_num = 0;
409+
_state = src_clone(r._state, &_err_num);
410+
if (_state == nullptr) error_handler(_err_num);
411+
}
412+
413+
// move constructor
414+
CallbackResampler(CallbackResampler &&r)
415+
: _state(r._state),
416+
_callback(r._callback),
417+
_current_buffer(std::move(r._current_buffer)),
418+
_buffer_ndim(r._buffer_ndim),
419+
_callback_error_msg(std::move(r._callback_error_msg)),
420+
_ratio(r._ratio),
421+
_converter_type(r._converter_type),
422+
_channels(r._channels) {
423+
r._state = nullptr;
424+
r._callback = nullptr;
425+
r._buffer_ndim = 0;
426+
r._ratio = 0.0;
427+
r._converter_type = 0;
428+
r._channels = 0;
429+
}
430+
431+
~CallbackResampler() { _destroy(); }
432+
433+
void set_buffer(const nb_array_f32 &new_buf) { _current_buffer = new_buf; }
434+
nb_array_f32 get_buffer() const { return _current_buffer; }
435+
size_t get_channels() { return _channels; }
436+
void set_callback_error(const std::string &error_msg) {
437+
_callback_error_msg = error_msg;
438+
}
439+
std::string get_callback_error() const { return _callback_error_msg; }
440+
void clear_callback_error() { _callback_error_msg = ""; }
441+
442+
nb_array_f32 callback(void) {
443+
auto input = _callback();
444+
445+
if (input.ndim() > 0 && _buffer_ndim == 0)
446+
_buffer_ndim = input.ndim();
447+
448+
_current_buffer = input;
449+
return input;
450+
}
451+
452+
nb::ndarray<nb::numpy, float> read(size_t frames) {
453+
// Allocate output array
454+
size_t total_elements = frames * _channels;
455+
float* output_data = new float[total_elements];
456+
457+
// Create capsule for memory management
458+
nb::capsule owner(output_data, [](void* p) noexcept {
459+
delete[] static_cast<float*>(p);
460+
});
461+
462+
if (_state == nullptr) _create();
463+
464+
// clear any previous callback error
465+
clear_callback_error();
466+
467+
// read from the callback - note: GIL is managed by the_callback_func
468+
// which acquires it only when calling the Python callback
469+
size_t output_frames_gen = 0;
470+
int err_code = 0;
471+
{
472+
nb::gil_scoped_release release;
473+
output_frames_gen = src_callback_read(_state, _ratio, (long)frames,
474+
output_data);
475+
// Get error code while GIL is released
476+
if (output_frames_gen == 0) {
477+
err_code = src_error(_state);
478+
}
479+
}
480+
481+
// check if callback had an error
482+
std::string callback_error = get_callback_error();
483+
if (!callback_error.empty()) {
484+
throw std::domain_error(callback_error);
485+
}
486+
487+
// check error status
488+
if (output_frames_gen == 0) {
489+
error_handler(err_code);
490+
}
491+
492+
// Create output ndarray with proper shape and stride
493+
size_t output_shape[2];
494+
int64_t output_stride[2];
495+
496+
// if there is only one channel and the input array had only on dimension
497+
// we also output a 1D array
498+
if (_channels == 1 && _buffer_ndim == 1) {
499+
output_shape[0] = output_frames_gen;
500+
output_stride[0] = sizeof(float);
501+
502+
return nb::ndarray<nb::numpy, float>(
503+
output_data,
504+
1,
505+
output_shape,
506+
owner,
507+
output_stride
508+
);
509+
} else {
510+
output_shape[0] = output_frames_gen;
511+
output_shape[1] = _channels;
512+
output_stride[0] = _channels * sizeof(float);
513+
output_stride[1] = sizeof(float);
514+
515+
return nb::ndarray<nb::numpy, float>(
516+
output_data,
517+
2,
518+
output_shape,
519+
owner,
520+
output_stride
521+
);
522+
}
523+
}
524+
525+
void set_starting_ratio(double new_ratio) {
526+
error_handler(src_set_ratio(_state, new_ratio));
527+
_ratio = new_ratio;
528+
}
529+
530+
void reset() { error_handler(src_reset(_state)); }
531+
532+
CallbackResampler clone() const { return CallbackResampler(*this); }
533+
CallbackResampler &__enter__() { return *this; }
534+
void __exit__(const nb::object &/*exc_type*/, const nb::object &/*exc*/,
535+
const nb::object &/*exc_tb*/) {
536+
_destroy();
537+
}
538+
};
539+
540+
namespace {
541+
542+
long the_callback_func(void *cb_data, float **data) {
543+
CallbackResampler *cb = static_cast<CallbackResampler *>(cb_data);
544+
int cb_channels = cb->get_channels();
545+
546+
size_t ndim = 0;
547+
size_t num_frames = 0;
548+
float* data_ptr = nullptr;
549+
550+
{
551+
nb::gil_scoped_acquire acquire;
552+
553+
// get the data as a numpy array
554+
auto input = cb->callback();
555+
ndim = input.ndim();
556+
557+
// end of stream is signaled by a None, which is cast to a ndarray with ndim == 0
558+
if (ndim == 0) return 0;
559+
560+
num_frames = input.shape(0);
561+
data_ptr = const_cast<float*>(input.data());
562+
}
563+
564+
// set the number of channels
565+
int channels = 1;
566+
if (ndim == 2) {
567+
channels = cb->get_buffer().shape(1);
568+
} else if (ndim > 2) {
569+
// Cannot throw exception in C callback - store error and return 0
570+
cb->set_callback_error("Input array should have at most 2 dimensions");
571+
return 0;
572+
}
573+
574+
if (channels != cb_channels || channels == 0) {
575+
// Cannot throw exception in C callback - store error and return 0
576+
cb->set_callback_error("Invalid number of channels in input data.");
577+
return 0;
578+
}
579+
580+
*data = data_ptr;
581+
582+
return (long)num_frames;
583+
}
584+
585+
} // namespace
586+
358587
} // namespace samplerate
359588

360589
namespace sr = samplerate;
@@ -486,9 +715,66 @@ NB_MODULE(samplerate, m) {
486715
.def_ro("channels", &sr::Resampler::_channels,
487716
"Number of channels.");
488717

718+
nb::class_<sr::CallbackResampler>(m_converters, "CallbackResampler",
719+
R"mydelimiter(
720+
CallbackResampler.
721+
722+
Parameters
723+
----------
724+
callback : function
725+
Function that returns new frames on each call, or `None` otherwise.
726+
Input data with one or more channels is represented as a 2D array of shape
727+
(`num_frames`, `num_channels`).
728+
A single channel can be provided as a 1D array of `num_frames` length.
729+
For use with `libsamplerate`, `input_data` is converted to 32-bit float and
730+
C (row-major) memory order.
731+
ratio : float
732+
Conversion ratio = output sample rate / input sample rate.
733+
converter_type : ConverterType, str, or int
734+
Sample rate converter.
735+
channels : int
736+
Number of channels.
737+
)mydelimiter")
738+
.def(nb::init<const callback_t &, double, const nb::object &, int>(),
739+
"callback"_a, "ratio"_a, "converter_type"_a = "sinc_best",
740+
"channels"_a = 1)
741+
.def(nb::init<sr::CallbackResampler>())
742+
.def("read", &sr::CallbackResampler::read, R"mydelimiter(
743+
Read a number of frames from the resampler.
744+
745+
Parameters
746+
----------
747+
num_frames : int
748+
Number of frames to read.
749+
750+
Returns
751+
-------
752+
output_data : ndarray
753+
Resampled frames as a (`num_output_frames`, `num_channels`) or
754+
(`num_output_frames`,) array. Note that this may return fewer frames
755+
than requested, for example when no more input is available.
756+
)mydelimiter",
757+
"num_frames"_a)
758+
.def("reset", &sr::CallbackResampler::reset, "Reset state.")
759+
.def("set_starting_ratio", &sr::CallbackResampler::set_starting_ratio,
760+
"Set the starting conversion ratio for the next `read` call.")
761+
.def("clone", &sr::CallbackResampler::clone,
762+
"Create a copy of the resampler object.")
763+
.def("__enter__", &sr::CallbackResampler::__enter__,
764+
nb::rv_policy::reference_internal)
765+
.def("__exit__", &sr::CallbackResampler::__exit__)
766+
.def_rw(
767+
"ratio", &sr::CallbackResampler::_ratio,
768+
"Conversion ratio = output sample rate / input sample rate.")
769+
.def_ro("converter_type", &sr::CallbackResampler::_converter_type,
770+
"Converter type.")
771+
.def_ro("channels", &sr::CallbackResampler::_channels,
772+
"Number of channels.");
773+
489774
// Convenience imports
490775
m.attr("ResamplingError") = m_exceptions.attr("ResamplingError");
491776
m.attr("resample") = m_converters.attr("resample");
492777
m.attr("Resampler") = m_converters.attr("Resampler");
778+
m.attr("CallbackResampler") = m_converters.attr("CallbackResampler");
493779
m.attr("ConverterType") = m_converters.attr("ConverterType");
494780
}

0 commit comments

Comments
 (0)