99import weakref
1010
1111from abc import abstractmethod
12+ from dataclasses import dataclass
1213from collections import defaultdict
1314from collections .abc import Callable , Coroutine , Generator , Sequence
1415from functools import wraps , partial
@@ -39,7 +40,30 @@ def _strict_shutdown_enabled() -> bool:
3940 return value .lower () in ("1" , "true" , "yes" , "on" )
4041
4142
43+ @dataclass
44+ class ShutdownSummary :
45+ cancelled_tasks : int = 0
46+ executor_active : int = 0
47+ pending_tasks : int = 0
48+ suppressed_errors : int = 0
49+ forced_interrupt : bool = False
50+
51+ @property
52+ def unclean (self ) -> bool :
53+ return bool (
54+ self .executor_active
55+ or self .pending_tasks
56+ or self .suppressed_errors
57+ or self .forced_interrupt
58+ )
59+
60+
4261class _DaemonThreadPoolExecutor (ThreadPoolExecutor ):
62+ def __init__ (self , * args , ** kwargs ) -> None :
63+ super ().__init__ (* args , ** kwargs )
64+ self ._active_count = 0
65+ self ._active_lock = threading .Lock ()
66+
4367 def _adjust_thread_count (self ) -> None :
4468 if self ._broken :
4569 return
@@ -61,6 +85,22 @@ def _adjust_thread_count(self) -> None:
6185 thread .start ()
6286 self ._threads .add (thread )
6387
88+ def submit (self , fn , / , * args , ** kwargs ):
89+ fut = super ().submit (fn , * args , ** kwargs )
90+ with self ._active_lock :
91+ self ._active_count += 1
92+
93+ def _decrement (_ ):
94+ with self ._active_lock :
95+ self ._active_count -= 1
96+
97+ fut .add_done_callback (_decrement )
98+ return fut
99+
100+ def active_count (self ) -> int :
101+ with self ._active_lock :
102+ return self ._active_count
103+
64104
65105class Complete (Exception ):
66106 """
@@ -178,11 +218,13 @@ class DefaultBackendProcess(BackendProcess):
178218 """
179219
180220 pubs : dict [str , Publisher ]
221+ _shutdown_errors : bool
181222
182223 def process (self , loop : asyncio .AbstractEventLoop ) -> None :
183224 main_func = None
184225 context = GraphContext (self .graph_address )
185226 coro_callables : dict [str , Callable [[], Coroutine [Any , Any , None ]]] = dict ()
227+ self ._shutdown_errors = False
186228
187229 try :
188230 self .pubs = dict ()
@@ -445,6 +487,8 @@ async def wrapped_task(msg: Any = None) -> None:
445487 except Exception :
446488 logger .error (f"Exception in Task: { task_address } " )
447489 logger .error (traceback .format_exc ())
490+ if self .term_ev .is_set ():
491+ self ._shutdown_errors = True
448492 if strict_shutdown :
449493 raise
450494
@@ -522,6 +566,7 @@ def run_loop(loop: asyncio.AbstractEventLoop):
522566@contextmanager
523567def new_threaded_event_loop (
524568 ev : threading .Event | None = None ,
569+ shutdown_summary : ShutdownSummary | None = None ,
525570) -> Generator [asyncio .AbstractEventLoop , None , None ]:
526571 """
527572 Create a new asyncio event loop running in a separate thread.
@@ -531,6 +576,8 @@ def new_threaded_event_loop(
531576
532577 :param ev: Optional event to signal when the loop is ready.
533578 :type ev: threading.Event | None
579+ :param shutdown_summary: Optional shutdown summary object to populate.
580+ :type shutdown_summary: ShutdownSummary | None
534581 :return: Context manager yielding the event loop.
535582 :rtype: Generator[asyncio.AbstractEventLoop, None, None]
536583 """
@@ -539,10 +586,10 @@ def new_threaded_event_loop(
539586 shutdown_suppress = threading .Event ()
540587 suppressed_shutdown_errors = {"count" : 0 }
541588 suppressed_lock = threading .Lock ()
589+ executor = None
542590 if not strict_shutdown :
543- loop .set_default_executor (
544- _DaemonThreadPoolExecutor (thread_name_prefix = "EZMSG" )
545- )
591+ executor = _DaemonThreadPoolExecutor (thread_name_prefix = "EZMSG" )
592+ loop .set_default_executor (executor )
546593 def _loop_exception_handler (
547594 loop_obj : asyncio .AbstractEventLoop , context : dict
548595 ) -> None :
@@ -568,36 +615,45 @@ def _loop_exception_handler(
568615 if not strict_shutdown :
569616 shutdown_suppress .set ()
570617 # Cancel and await remaining tasks before stopping the loop.
571- async def _cancel_remaining () -> int :
618+ async def _cancel_remaining (timeout : float = 1.0 ) -> tuple [ int , int ] :
572619 tasks = [
573620 t
574621 for t in asyncio .all_tasks ()
575622 if t is not asyncio .current_task () and not t .done ()
576623 ]
577624 for t in tasks :
578625 t .cancel ()
579- if tasks :
580- await asyncio .wait (tasks )
581- return len (tasks )
626+ if not tasks :
627+ return 0 , 0
628+ _ , pending = await asyncio .wait (tasks , timeout = timeout )
629+ return len (tasks ), len (pending )
582630
583631 cancelled_count = 0
632+ pending_count = 0
584633 forced_interrupt = False
585634 fut = asyncio .run_coroutine_threadsafe (_cancel_remaining (), loop )
586635 try :
587- cancelled_count = fut .result ()
636+ cancelled_count , pending_count = fut .result ()
588637 except KeyboardInterrupt :
589638 forced_interrupt = True
590639 fut .cancel ()
591640 except Exception :
592641 cancelled_count = 0
642+ pending_count = 0
593643
594644 suppressed_count = suppressed_shutdown_errors ["count" ]
595- if cancelled_count or suppressed_count or forced_interrupt :
645+ if cancelled_count or suppressed_count or forced_interrupt or pending_count :
596646 if forced_interrupt and not cancelled_count and not suppressed_count :
597647 logger .warning (
598648 "Shutdown interrupted; tasks may still be running. "
599649 "Re-run with EZMSG_STRICT_SHUTDOWN=1 to debug tasks with poor shutdown behavior."
600650 )
651+ elif pending_count :
652+ logger .warning (
653+ "Shutdown timed out waiting for %d task(s). "
654+ "Re-run with EZMSG_STRICT_SHUTDOWN=1 to debug tasks with poor shutdown behavior." ,
655+ pending_count ,
656+ )
601657 else :
602658 logger .warning (
603659 "Shutdown suppressed %d error(s) and cancelled %d task(s). "
@@ -607,6 +663,15 @@ async def _cancel_remaining() -> int:
607663 cancelled_count ,
608664 )
609665
666+ if shutdown_summary is not None :
667+ shutdown_summary .cancelled_tasks = cancelled_count
668+ shutdown_summary .pending_tasks = pending_count
669+ shutdown_summary .executor_active = (
670+ executor .active_count () if executor is not None else 0
671+ )
672+ shutdown_summary .suppressed_errors = suppressed_count
673+ shutdown_summary .forced_interrupt = forced_interrupt
674+
610675 loop .call_soon_threadsafe (loop .stop )
611676 thread .join ()
612677 loop .close ()
0 commit comments