@@ -355,6 +355,235 @@ class Resampler {
355355 Resampler clone () const { return Resampler (*this ); }
356356};
357357
358+ namespace {
359+
360+ long the_callback_func (void *cb_data, float **data);
361+
362+ } // namespace
363+
364+ class CallbackResampler {
365+ private:
366+ SRC_STATE *_state = nullptr ;
367+ callback_t _callback = nullptr ;
368+ nb_array_f32 _current_buffer;
369+ size_t _buffer_ndim = 0 ;
370+ std::string _callback_error_msg = " " ;
371+
372+ public:
373+ double _ratio = 0.0 ;
374+ int _converter_type = 0 ;
375+ size_t _channels = 0 ;
376+
377+ private:
378+ void _create () {
379+ int _err_num = 0 ;
380+ _state = src_callback_new (the_callback_func, _converter_type, (int )_channels,
381+ &_err_num, static_cast <void *>(this ));
382+ if (_state == nullptr ) error_handler (_err_num);
383+ }
384+
385+ void _destroy () {
386+ if (_state != nullptr ) {
387+ src_delete (_state);
388+ _state = nullptr ;
389+ }
390+ }
391+
392+ public:
393+ CallbackResampler (const callback_t &callback_func, double ratio,
394+ const nb::object &converter_type, size_t channels)
395+ : _callback(callback_func),
396+ _ratio (ratio),
397+ _converter_type(get_converter_type(converter_type)),
398+ _channels(channels) {
399+ _create ();
400+ }
401+
402+ // copy constructor
403+ CallbackResampler (const CallbackResampler &r)
404+ : _callback(r._callback),
405+ _ratio(r._ratio),
406+ _converter_type(r._converter_type),
407+ _channels(r._channels) {
408+ int _err_num = 0 ;
409+ _state = src_clone (r._state , &_err_num);
410+ if (_state == nullptr ) error_handler (_err_num);
411+ }
412+
413+ // move constructor
414+ CallbackResampler (CallbackResampler &&r)
415+ : _state(r._state),
416+ _callback(r._callback),
417+ _current_buffer(std::move(r._current_buffer)),
418+ _buffer_ndim(r._buffer_ndim),
419+ _callback_error_msg(std::move(r._callback_error_msg)),
420+ _ratio(r._ratio),
421+ _converter_type(r._converter_type),
422+ _channels(r._channels) {
423+ r._state = nullptr ;
424+ r._callback = nullptr ;
425+ r._buffer_ndim = 0 ;
426+ r._ratio = 0.0 ;
427+ r._converter_type = 0 ;
428+ r._channels = 0 ;
429+ }
430+
431+ ~CallbackResampler () { _destroy (); }
432+
433+ void set_buffer (const nb_array_f32 &new_buf) { _current_buffer = new_buf; }
434+ nb_array_f32 get_buffer () const { return _current_buffer; }
435+ size_t get_channels () { return _channels; }
436+ void set_callback_error (const std::string &error_msg) {
437+ _callback_error_msg = error_msg;
438+ }
439+ std::string get_callback_error () const { return _callback_error_msg; }
440+ void clear_callback_error () { _callback_error_msg = " " ; }
441+
442+ nb_array_f32 callback (void ) {
443+ auto input = _callback ();
444+
445+ if (input.ndim () > 0 && _buffer_ndim == 0 )
446+ _buffer_ndim = input.ndim ();
447+
448+ _current_buffer = input;
449+ return input;
450+ }
451+
452+ nb::ndarray<nb::numpy, float > read (size_t frames) {
453+ // Allocate output array
454+ size_t total_elements = frames * _channels;
455+ float * output_data = new float [total_elements];
456+
457+ // Create capsule for memory management
458+ nb::capsule owner (output_data, [](void * p) noexcept {
459+ delete[] static_cast <float *>(p);
460+ });
461+
462+ if (_state == nullptr ) _create ();
463+
464+ // clear any previous callback error
465+ clear_callback_error ();
466+
467+ // read from the callback - note: GIL is managed by the_callback_func
468+ // which acquires it only when calling the Python callback
469+ size_t output_frames_gen = 0 ;
470+ int err_code = 0 ;
471+ {
472+ nb::gil_scoped_release release;
473+ output_frames_gen = src_callback_read (_state, _ratio, (long )frames,
474+ output_data);
475+ // Get error code while GIL is released
476+ if (output_frames_gen == 0 ) {
477+ err_code = src_error (_state);
478+ }
479+ }
480+
481+ // check if callback had an error
482+ std::string callback_error = get_callback_error ();
483+ if (!callback_error.empty ()) {
484+ throw std::domain_error (callback_error);
485+ }
486+
487+ // check error status
488+ if (output_frames_gen == 0 ) {
489+ error_handler (err_code);
490+ }
491+
492+ // Create output ndarray with proper shape and stride
493+ size_t output_shape[2 ];
494+ int64_t output_stride[2 ];
495+
496+ // if there is only one channel and the input array had only on dimension
497+ // we also output a 1D array
498+ if (_channels == 1 && _buffer_ndim == 1 ) {
499+ output_shape[0 ] = output_frames_gen;
500+ output_stride[0 ] = sizeof (float );
501+
502+ return nb::ndarray<nb::numpy, float >(
503+ output_data,
504+ 1 ,
505+ output_shape,
506+ owner,
507+ output_stride
508+ );
509+ } else {
510+ output_shape[0 ] = output_frames_gen;
511+ output_shape[1 ] = _channels;
512+ output_stride[0 ] = _channels * sizeof (float );
513+ output_stride[1 ] = sizeof (float );
514+
515+ return nb::ndarray<nb::numpy, float >(
516+ output_data,
517+ 2 ,
518+ output_shape,
519+ owner,
520+ output_stride
521+ );
522+ }
523+ }
524+
525+ void set_starting_ratio (double new_ratio) {
526+ error_handler (src_set_ratio (_state, new_ratio));
527+ _ratio = new_ratio;
528+ }
529+
530+ void reset () { error_handler (src_reset (_state)); }
531+
532+ CallbackResampler clone () const { return CallbackResampler (*this ); }
533+ CallbackResampler &__enter__ () { return *this ; }
534+ void __exit__ (const nb::object &/* exc_type*/ , const nb::object &/* exc*/ ,
535+ const nb::object &/* exc_tb*/ ) {
536+ _destroy ();
537+ }
538+ };
539+
540+ namespace {
541+
542+ long the_callback_func (void *cb_data, float **data) {
543+ CallbackResampler *cb = static_cast <CallbackResampler *>(cb_data);
544+ int cb_channels = cb->get_channels ();
545+
546+ size_t ndim = 0 ;
547+ size_t num_frames = 0 ;
548+ float * data_ptr = nullptr ;
549+
550+ {
551+ nb::gil_scoped_acquire acquire;
552+
553+ // get the data as a numpy array
554+ auto input = cb->callback ();
555+ ndim = input.ndim ();
556+
557+ // end of stream is signaled by a None, which is cast to a ndarray with ndim == 0
558+ if (ndim == 0 ) return 0 ;
559+
560+ num_frames = input.shape (0 );
561+ data_ptr = const_cast <float *>(input.data ());
562+ }
563+
564+ // set the number of channels
565+ int channels = 1 ;
566+ if (ndim == 2 ) {
567+ channels = cb->get_buffer ().shape (1 );
568+ } else if (ndim > 2 ) {
569+ // Cannot throw exception in C callback - store error and return 0
570+ cb->set_callback_error (" Input array should have at most 2 dimensions" );
571+ return 0 ;
572+ }
573+
574+ if (channels != cb_channels || channels == 0 ) {
575+ // Cannot throw exception in C callback - store error and return 0
576+ cb->set_callback_error (" Invalid number of channels in input data." );
577+ return 0 ;
578+ }
579+
580+ *data = data_ptr;
581+
582+ return (long )num_frames;
583+ }
584+
585+ } // namespace
586+
358587} // namespace samplerate
359588
360589namespace sr = samplerate;
@@ -486,9 +715,66 @@ NB_MODULE(samplerate, m) {
486715 .def_ro (" channels" , &sr::Resampler::_channels,
487716 " Number of channels." );
488717
718+ nb::class_<sr::CallbackResampler>(m_converters, " CallbackResampler" ,
719+ R"mydelimiter(
720+ CallbackResampler.
721+
722+ Parameters
723+ ----------
724+ callback : function
725+ Function that returns new frames on each call, or `None` otherwise.
726+ Input data with one or more channels is represented as a 2D array of shape
727+ (`num_frames`, `num_channels`).
728+ A single channel can be provided as a 1D array of `num_frames` length.
729+ For use with `libsamplerate`, `input_data` is converted to 32-bit float and
730+ C (row-major) memory order.
731+ ratio : float
732+ Conversion ratio = output sample rate / input sample rate.
733+ converter_type : ConverterType, str, or int
734+ Sample rate converter.
735+ channels : int
736+ Number of channels.
737+ )mydelimiter" )
738+ .def (nb::init<const callback_t &, double , const nb::object &, int >(),
739+ " callback" _a, " ratio" _a, " converter_type" _a = " sinc_best" ,
740+ " channels" _a = 1 )
741+ .def (nb::init<sr::CallbackResampler>())
742+ .def (" read" , &sr::CallbackResampler::read, R"mydelimiter(
743+ Read a number of frames from the resampler.
744+
745+ Parameters
746+ ----------
747+ num_frames : int
748+ Number of frames to read.
749+
750+ Returns
751+ -------
752+ output_data : ndarray
753+ Resampled frames as a (`num_output_frames`, `num_channels`) or
754+ (`num_output_frames`,) array. Note that this may return fewer frames
755+ than requested, for example when no more input is available.
756+ )mydelimiter" ,
757+ " num_frames" _a)
758+ .def (" reset" , &sr::CallbackResampler::reset, " Reset state." )
759+ .def (" set_starting_ratio" , &sr::CallbackResampler::set_starting_ratio,
760+ " Set the starting conversion ratio for the next `read` call." )
761+ .def (" clone" , &sr::CallbackResampler::clone,
762+ " Create a copy of the resampler object." )
763+ .def (" __enter__" , &sr::CallbackResampler::__enter__,
764+ nb::rv_policy::reference_internal)
765+ .def (" __exit__" , &sr::CallbackResampler::__exit__)
766+ .def_rw (
767+ " ratio" , &sr::CallbackResampler::_ratio,
768+ " Conversion ratio = output sample rate / input sample rate." )
769+ .def_ro (" converter_type" , &sr::CallbackResampler::_converter_type,
770+ " Converter type." )
771+ .def_ro (" channels" , &sr::CallbackResampler::_channels,
772+ " Number of channels." );
773+
489774 // Convenience imports
490775 m.attr (" ResamplingError" ) = m_exceptions.attr (" ResamplingError" );
491776 m.attr (" resample" ) = m_converters.attr (" resample" );
492777 m.attr (" Resampler" ) = m_converters.attr (" Resampler" );
778+ m.attr (" CallbackResampler" ) = m_converters.attr (" CallbackResampler" );
493779 m.attr (" ConverterType" ) = m_converters.attr (" ConverterType" );
494780}
0 commit comments