22import concurrent .futures
33import logging
44import inspect
5+ import os
56import time
67import traceback
78import threading
9+ import weakref
810
911from abc import abstractmethod
1012from collections import defaultdict
1113from collections .abc import Callable , Coroutine , Generator , Sequence
1214from functools import wraps , partial
13- from copy import deepcopy
15+ from concurrent .futures import ThreadPoolExecutor
16+ from concurrent .futures .thread import _worker
1417from multiprocessing import Process
1518from multiprocessing .synchronize import Event as EventType
1619from multiprocessing .synchronize import Barrier as BarrierType
2023
2124from .stream import Stream , InputStream , OutputStream
2225from .unit import Unit , TIMEIT_ATTR , SUBSCRIBES_ATTR
23- from .messagechannel import LeakyQueue
2426
2527from .graphcontext import GraphContext
2628from .pubclient import Publisher
2931
3032logger = logging .getLogger ("ezmsg" )
3133
# Name of the environment variable that opts in to strict shutdown behaviour.
STRICT_SHUTDOWN_ENV = "EZMSG_STRICT_SHUTDOWN"

# Values (compared case-insensitively) that switch strict shutdown on.
_STRICT_SHUTDOWN_TRUTHY = frozenset({"1", "true", "yes", "on"})


def _strict_shutdown_enabled() -> bool:
    """Return True when the strict-shutdown environment flag is switched on.

    The variable named by ``STRICT_SHUTDOWN_ENV`` is read on every call.
    Surrounding whitespace and letter case are ignored, so ``" True "`` is
    accepted just like ``"1"``.  An unset or empty variable, or any other
    value, yields False.
    """
    raw_value = os.environ.get(STRICT_SHUTDOWN_ENV, "")
    # Strip before comparing: values exported from shell scripts or .env
    # files frequently carry stray spaces or a trailing newline.
    return raw_value.strip().lower() in _STRICT_SHUTDOWN_TRUTHY
40+
41+
class _DaemonThreadPoolExecutor(ThreadPoolExecutor):
    """A ``ThreadPoolExecutor`` whose worker threads are daemon threads.

    Only ``_adjust_thread_count`` is overridden; submission, shutdown and the
    worker loop itself (``concurrent.futures.thread._worker``) are inherited.
    Daemon workers do not keep the interpreter alive on their own.

    NOTE(review): this relies on private ``concurrent.futures`` internals
    (``_worker``, ``_threads``, ``_work_queue``, ``_idle_semaphore`` ...), so
    it should be re-checked against each new Python release.
    """

    def _adjust_thread_count(self) -> None:
        """Start one additional daemon worker if the pool actually needs it.

        Nothing is started when the pool is marked broken, when an idle
        worker is already waiting for work, or when ``max_workers`` threads
        exist.
        """
        if self._broken:
            return

        # An idle worker will pick the queued item up; do not spawn another
        # thread for it.  The inherited worker loop releases this semaphore
        # each time a worker finishes an item.
        if self._idle_semaphore.acquire(timeout=0):
            return

        num_threads = len(self._threads)
        if num_threads >= self._max_workers:
            return

        def _wake_workers(_ref, work_queue=self._work_queue) -> None:
            # Runs when the executor is garbage-collected: the ``None``
            # sentinel unblocks a worker waiting on the queue so it can exit
            # instead of sleeping forever.
            work_queue.put(None)

        executor_ref = weakref.ref(self, _wake_workers)

        if hasattr(self, "_create_worker_context"):
            # Newer CPython: ``_worker(executor_reference, ctx, work_queue)``.
            worker_args = (
                executor_ref,
                self._create_worker_context(),
                self._work_queue,
            )
        else:
            # Older CPython:
            # ``_worker(executor_reference, work_queue, initializer, initargs)``.
            worker_args = (
                executor_ref,
                self._work_queue,
                self._initializer,
                self._initargs,
            )

        thread_name = f"{self._thread_name_prefix or 'ThreadPool'}_{num_threads}"
        thread = threading.Thread(
            name=thread_name,
            target=_worker,
            args=worker_args,
            daemon=True,
        )
        thread.start()
        self._threads.add(thread)
63+
3264
3365class Complete (Exception ):
3466 """
@@ -228,9 +260,14 @@ async def setup_state():
228260 ),
229261 loop = loop ,
230262 ).result ()
263+
264+ except asyncio .CancelledError :
265+ pass
266+
231267 except Exception :
232268 self .start_barrier .abort ()
233- logger .error (f"{ traceback .format_exc ()} " )
269+ # logger.error(f"{traceback.format_exc()}")
270+ raise
234271
235272 try :
236273 logger .debug ("Waiting at start barrier!" )
@@ -260,9 +297,9 @@ async def coro_wrapper(coro):
260297 fn (unit )
261298 except NormalTermination :
262299 self .term_ev .set ()
263- except Exception :
264- logger .error (f"Exception in Main: { unit .address } " )
265- logger .error (traceback .format_exc ())
300+ # except Exception:
301+ # logger.error(f"Exception in Main: {unit.address}")
302+ # logger.error(traceback.format_exc())
266303
267304 while True :
268305 try :
@@ -352,6 +389,7 @@ def task_wrapper(
352389 self , unit : Unit , task : Callable
353390 ) -> Callable [..., Coroutine [Any , Any , None ]]:
354391 task_address = f"{ unit .address } :{ task .__name__ } "
392+ strict_shutdown = _strict_shutdown_enabled ()
355393
356394 async def publish (stream : Stream , obj : Any ) -> None :
357395 if stream .address in self .pubs :
@@ -400,9 +438,15 @@ async def wrapped_task(msg: Any = None) -> None:
400438 self .term_ev .set ()
401439 raise
402440
441+ except asyncio .CancelledError :
442+ # Normal during shutdown; propagate without logging.
443+ raise
444+
403445 except Exception :
404446 logger .error (f"Exception in Task: { task_address } " )
405447 logger .error (traceback .format_exc ())
448+ if strict_shutdown :
449+ raise
406450
407451 return wrapped_task
408452
@@ -491,6 +535,24 @@ def new_threaded_event_loop(
491535 :rtype: Generator[asyncio.AbstractEventLoop, None, None]
492536 """
493537 loop = asyncio .new_event_loop ()
538+ strict_shutdown = _strict_shutdown_enabled ()
539+ shutdown_suppress = threading .Event ()
540+ suppressed_shutdown_errors = {"count" : 0 }
541+ suppressed_lock = threading .Lock ()
542+ if not strict_shutdown :
543+ loop .set_default_executor (
544+ _DaemonThreadPoolExecutor (thread_name_prefix = "EZMSG" )
545+ )
546+ def _loop_exception_handler (
547+ loop_obj : asyncio .AbstractEventLoop , context : dict
548+ ) -> None :
549+ if shutdown_suppress .is_set ():
550+ with suppressed_lock :
551+ suppressed_shutdown_errors ["count" ] += 1
552+ return
553+ loop_obj .default_exception_handler (context )
554+
555+ loop .set_exception_handler (_loop_exception_handler )
494556 thread = threading .Thread (target = run_loop , name = "TaskThread" , args = (loop ,))
495557 thread .start ()
496558
@@ -502,6 +564,49 @@ def new_threaded_event_loop(
502564 logger .debug ("Waiting at event..." )
503565 # ev.wait()
504566 logger .debug ("Stopping and closing task thread" )
567+
568+ if not strict_shutdown :
569+ shutdown_suppress .set ()
570+ # Cancel and await remaining tasks before stopping the loop.
571+ async def _cancel_remaining () -> int :
572+ tasks = [
573+ t
574+ for t in asyncio .all_tasks ()
575+ if t is not asyncio .current_task () and not t .done ()
576+ ]
577+ for t in tasks :
578+ t .cancel ()
579+ if tasks :
580+ await asyncio .wait (tasks )
581+ return len (tasks )
582+
583+ cancelled_count = 0
584+ forced_interrupt = False
585+ fut = asyncio .run_coroutine_threadsafe (_cancel_remaining (), loop )
586+ try :
587+ cancelled_count = fut .result ()
588+ except KeyboardInterrupt :
589+ forced_interrupt = True
590+ fut .cancel ()
591+ except Exception :
592+ cancelled_count = 0
593+
594+ suppressed_count = suppressed_shutdown_errors ["count" ]
595+ if cancelled_count or suppressed_count or forced_interrupt :
596+ if forced_interrupt and not cancelled_count and not suppressed_count :
597+ logger .warning (
598+ "Shutdown interrupted; tasks may still be running. "
599+ "Re-run with EZMSG_STRICT_SHUTDOWN=1 to debug tasks with poor shutdown behavior."
600+ )
601+ else :
602+ logger .warning (
603+ "Shutdown suppressed %d error(s) and cancelled %d task(s). "
604+ "Shutdown was NOT clean; re-run with EZMSG_STRICT_SHUTDOWN=1 "
605+ "to debug tasks with poor shutdown behavior." ,
606+ suppressed_count ,
607+ cancelled_count ,
608+ )
609+
505610 loop .call_soon_threadsafe (loop .stop )
506611 thread .join ()
507612 loop .close ()
0 commit comments