@@ -23,6 +23,7 @@ use opsqueue::{
2323use ux_serde:: u63;
2424
2525use crate :: {
26+ async_util,
2627 common:: { run_unless_interrupted, start_runtime, SubmissionId , SubmissionStatus } ,
2728 errors:: { self , CError , CPyResult , FatalPythonException } ,
2829} ;
@@ -362,9 +363,9 @@ impl ProducerClient {
362363 ) -> PyResult < Bound < ' p , PyAny > > {
363364 let me = self . clone ( ) ;
364365 let _tokio_active_runtime_guard = me. runtime . enter ( ) ;
365- crate :: async_util:: future_into_py (
366+ async_util:: future_into_py (
366367 py,
367- crate :: async_util:: async_allow_threads ( Box :: pin ( async move {
368+ async_util:: async_allow_threads ( Box :: pin ( async move {
368369 match me. stream_completed_submission_chunks ( submission_id) . await {
369370 Ok ( iter) => {
370371 let async_iter = PyChunksAsyncIter :: from ( iter) ;
@@ -467,7 +468,7 @@ pub type ChunksStream = BoxStream<'static, CPyResult<Vec<u8>, ChunkRetrievalErro
467468
468469#[ pyclass]
469470pub struct PyChunksIter {
470- stream : tokio:: sync:: Mutex < ChunksStream > ,
471+ stream : Arc < tokio:: sync:: Mutex < ChunksStream > > ,
471472 runtime : Arc < tokio:: runtime:: Runtime > ,
472473}
473474
@@ -480,7 +481,7 @@ impl PyChunksIter {
480481 . map_err ( CError )
481482 . boxed ( ) ;
482483 Self {
483- stream : tokio:: sync:: Mutex :: new ( stream) ,
484+ stream : Arc :: new ( tokio:: sync:: Mutex :: new ( stream) ) ,
484485 runtime : client. runtime . clone ( ) ,
485486 }
486487 }
@@ -492,11 +493,21 @@ impl PyChunksIter {
492493 slf
493494 }
494495
495- fn __next__ ( mut slf : PyRefMut < ' _ , Self > ) -> Option < CPyResult < Vec < u8 > , ChunkRetrievalError > > {
496- let me = & mut * slf;
497- let runtime = & mut me. runtime ;
498- let stream = & mut me. stream ;
499- runtime. block_on ( async { stream. lock ( ) . await . next ( ) . await } )
496+ fn __next__ ( & self , py : Python < ' _ > ) -> Option < CPyResult < Vec < u8 > , ChunkRetrievalError > > {
497+ // The only time we need the GIL is when turning the result back.
498+ // By unlocking here, we reduce the chance of deadlocks.
499+ py. allow_threads ( move || {
500+ let runtime = self . runtime . clone ( ) ;
501+ let stream = self . stream . clone ( ) ;
502+ runtime. block_on ( async {
503+ // We lock the stream in a separate Tokio task
504+ // that explicitly runs on the runtime thread rather than on the main Python thread.
505+ // This reduces the possibility for deadlocks even further.
506+ tokio:: task:: spawn ( async move { stream. lock ( ) . await . next ( ) . await } )
507+ . await
508+ . expect ( "Top-level spawn to succeed" )
509+ } )
510+ } )
500511 }
501512
502513 fn __aiter__ ( slf : PyRef < ' _ , Self > ) -> PyRef < ' _ , Self > {
@@ -513,7 +524,7 @@ pub struct PyChunksAsyncIter {
513524impl From < PyChunksIter > for PyChunksAsyncIter {
514525 fn from ( iter : PyChunksIter ) -> Self {
515526 Self {
516- stream : Arc :: new ( iter. stream ) ,
527+ stream : iter. stream ,
517528 runtime : iter. runtime ,
518529 }
519530 }
@@ -525,16 +536,27 @@ impl PyChunksAsyncIter {
525536 slf
526537 }
527538
528- fn __anext__ ( slf : PyRef < ' _ , Self > ) -> PyResult < Bound < ' _ , PyAny > > {
529- let _tokio_active_runtime_guard = slf. runtime . enter ( ) ;
530- let stream = slf. stream . clone ( ) ;
531- pyo3_async_runtimes:: tokio:: future_into_py ( slf. py ( ) , async move {
532- let res = stream. lock ( ) . await . next ( ) . await ;
533- match res {
534- None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
535- Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
536- Some ( Err ( e) ) => Err ( e. into ( ) ) ,
537- }
538- } )
539+ fn __anext__ < ' py > ( & self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
540+ let stream = self . stream . clone ( ) ;
541+ let _tokio_active_runtime_guard = self . runtime . enter ( ) ;
542+
543+ async_util:: future_into_py (
544+ py,
545+ // The only time we need the GIL is when turning the result into Python datatypes.
546+ // By unlocking here, we reduce the chance of deadlocks.
547+ async_util:: async_allow_threads ( Box :: pin ( async move {
548+ // We lock the stream in a separate Tokio task
549+ // that explicitly runs on the runtime thread rather than on the main Python thread.
550+ // This reduces the possibility for deadlocks even further.
551+ let res = tokio:: task:: spawn ( async move { stream. lock ( ) . await . next ( ) . await } )
552+ . await
553+ . expect ( "Top-level spawn to succeed" ) ;
554+ match res {
555+ None => Err ( PyStopAsyncIteration :: new_err ( ( ) ) ) ,
556+ Some ( Ok ( val) ) => Ok ( Some ( val) ) ,
557+ Some ( Err ( e) ) => Err ( e. into ( ) ) ,
558+ }
559+ } ) ) ,
560+ )
539561 }
540562}
0 commit comments