Skip to content

Commit 4cec249

Browse files
Merge pull request #220 from ezmsg-org/fix/219-shutdown-pending
Fix: Multiple shutdown issues and enhance shutdown UX
2 parents fa5bc59 + 5a8b6da commit 4cec249

6 files changed

Lines changed: 557 additions & 7 deletions

File tree

src/ezmsg/core/backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,15 @@ def graph_address(self) -> AddressType | None:
222222
return self._graph_context.graph_address
223223
return self._graph_address
224224

225+
@property
226+
def strict_shutdown(self) -> bool:
227+
value = os.environ.get("EZMSG_STRICT_SHUTDOWN", "")
228+
return value.lower() in ("1", "true", "yes", "on")
229+
230+
@strict_shutdown.setter
231+
def strict_shutdown(self, value: bool) -> None:
232+
os.environ["EZMSG_STRICT_SHUTDOWN"] = "1" if value else "0"
233+
225234
@property
226235
def graph_server_spawned(self) -> bool:
227236
return self._graph_server_spawned

src/ezmsg/core/backendprocess.py

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
import concurrent.futures
33
import logging
44
import inspect
5+
import os
56
import time
67
import traceback
78
import threading
9+
import weakref
810

911
from abc import abstractmethod
1012
from collections import defaultdict
1113
from collections.abc import Callable, Coroutine, Generator, Sequence
1214
from functools import wraps, partial
13-
from copy import deepcopy
15+
from concurrent.futures import ThreadPoolExecutor
16+
from concurrent.futures.thread import _worker
1417
from multiprocessing import Process
1518
from multiprocessing.synchronize import Event as EventType
1619
from multiprocessing.synchronize import Barrier as BarrierType
@@ -20,7 +23,6 @@
2023

2124
from .stream import Stream, InputStream, OutputStream
2225
from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR
23-
from .messagechannel import LeakyQueue
2426

2527
from .graphcontext import GraphContext
2628
from .pubclient import Publisher
@@ -29,6 +31,36 @@
2931

3032
logger = logging.getLogger("ezmsg")
3133

34+
STRICT_SHUTDOWN_ENV = "EZMSG_STRICT_SHUTDOWN"
35+
36+
37+
def _strict_shutdown_enabled() -> bool:
38+
value = os.environ.get(STRICT_SHUTDOWN_ENV, "")
39+
return value.lower() in ("1", "true", "yes", "on")
40+
41+
42+
class _DaemonThreadPoolExecutor(ThreadPoolExecutor):
43+
def _adjust_thread_count(self) -> None:
44+
if self._broken:
45+
return
46+
num_threads = len(self._threads)
47+
if num_threads >= self._max_workers:
48+
return
49+
thread_name = f"{self._thread_name_prefix or 'ThreadPool'}_{num_threads}"
50+
thread = threading.Thread(
51+
name=thread_name,
52+
target=_worker,
53+
args=(
54+
weakref.ref(self),
55+
self._work_queue,
56+
self._initializer,
57+
self._initargs,
58+
),
59+
)
60+
thread.daemon = True
61+
thread.start()
62+
self._threads.add(thread)
63+
3264

3365
class Complete(Exception):
3466
"""
@@ -228,9 +260,14 @@ async def setup_state():
228260
),
229261
loop=loop,
230262
).result()
263+
264+
except asyncio.CancelledError:
265+
pass
266+
231267
except Exception:
232268
self.start_barrier.abort()
233-
logger.error(f"{traceback.format_exc()}")
269+
# logger.error(f"{traceback.format_exc()}")
270+
raise
234271

235272
try:
236273
logger.debug("Waiting at start barrier!")
@@ -260,9 +297,9 @@ async def coro_wrapper(coro):
260297
fn(unit)
261298
except NormalTermination:
262299
self.term_ev.set()
263-
except Exception:
264-
logger.error(f"Exception in Main: {unit.address}")
265-
logger.error(traceback.format_exc())
300+
# except Exception:
301+
# logger.error(f"Exception in Main: {unit.address}")
302+
# logger.error(traceback.format_exc())
266303

267304
while True:
268305
try:
@@ -352,6 +389,7 @@ def task_wrapper(
352389
self, unit: Unit, task: Callable
353390
) -> Callable[..., Coroutine[Any, Any, None]]:
354391
task_address = f"{unit.address}:{task.__name__}"
392+
strict_shutdown = _strict_shutdown_enabled()
355393

356394
async def publish(stream: Stream, obj: Any) -> None:
357395
if stream.address in self.pubs:
@@ -400,9 +438,15 @@ async def wrapped_task(msg: Any = None) -> None:
400438
self.term_ev.set()
401439
raise
402440

441+
except asyncio.CancelledError:
442+
# Normal during shutdown; propagate without logging.
443+
raise
444+
403445
except Exception:
404446
logger.error(f"Exception in Task: {task_address}")
405447
logger.error(traceback.format_exc())
448+
if strict_shutdown:
449+
raise
406450

407451
return wrapped_task
408452

@@ -491,6 +535,24 @@ def new_threaded_event_loop(
491535
:rtype: Generator[asyncio.AbstractEventLoop, None, None]
492536
"""
493537
loop = asyncio.new_event_loop()
538+
strict_shutdown = _strict_shutdown_enabled()
539+
shutdown_suppress = threading.Event()
540+
suppressed_shutdown_errors = {"count": 0}
541+
suppressed_lock = threading.Lock()
542+
if not strict_shutdown:
543+
loop.set_default_executor(
544+
_DaemonThreadPoolExecutor(thread_name_prefix="EZMSG")
545+
)
546+
def _loop_exception_handler(
547+
loop_obj: asyncio.AbstractEventLoop, context: dict
548+
) -> None:
549+
if shutdown_suppress.is_set():
550+
with suppressed_lock:
551+
suppressed_shutdown_errors["count"] += 1
552+
return
553+
loop_obj.default_exception_handler(context)
554+
555+
loop.set_exception_handler(_loop_exception_handler)
494556
thread = threading.Thread(target=run_loop, name="TaskThread", args=(loop,))
495557
thread.start()
496558

@@ -502,6 +564,49 @@ def new_threaded_event_loop(
502564
logger.debug("Waiting at event...")
503565
# ev.wait()
504566
logger.debug("Stopping and closing task thread")
567+
568+
if not strict_shutdown:
569+
shutdown_suppress.set()
570+
# Cancel and await remaining tasks before stopping the loop.
571+
async def _cancel_remaining() -> int:
572+
tasks = [
573+
t
574+
for t in asyncio.all_tasks()
575+
if t is not asyncio.current_task() and not t.done()
576+
]
577+
for t in tasks:
578+
t.cancel()
579+
if tasks:
580+
await asyncio.wait(tasks)
581+
return len(tasks)
582+
583+
cancelled_count = 0
584+
forced_interrupt = False
585+
fut = asyncio.run_coroutine_threadsafe(_cancel_remaining(), loop)
586+
try:
587+
cancelled_count = fut.result()
588+
except KeyboardInterrupt:
589+
forced_interrupt = True
590+
fut.cancel()
591+
except Exception:
592+
cancelled_count = 0
593+
594+
suppressed_count = suppressed_shutdown_errors["count"]
595+
if cancelled_count or suppressed_count or forced_interrupt:
596+
if forced_interrupt and not cancelled_count and not suppressed_count:
597+
logger.warning(
598+
"Shutdown interrupted; tasks may still be running. "
599+
"Re-run with EZMSG_STRICT_SHUTDOWN=1 to debug tasks with poor shutdown behavior."
600+
)
601+
else:
602+
logger.warning(
603+
"Shutdown suppressed %d error(s) and cancelled %d task(s). "
604+
"Shutdown was NOT clean; re-run with EZMSG_STRICT_SHUTDOWN=1 "
605+
"to debug tasks with poor shutdown behavior.",
606+
suppressed_count,
607+
cancelled_count,
608+
)
609+
505610
loop.call_soon_threadsafe(loop.stop)
506611
thread.join()
507612
loop.close()

src/ezmsg/core/netprotocol.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,10 @@ async def read_str(reader: asyncio.StreamReader) -> str:
224224

225225

226226
async def close_stream_writer(writer: asyncio.StreamWriter):
227-
writer.close()
227+
try:
228+
writer.close()
229+
except RuntimeError:
230+
return # Event loop is closed, transport is already gone
228231
# ConnectionResetError can be raised on wait_closed.
229232
# See: https://github.com/python/cpython/issues/83037
230233
try:
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import asyncio
2+
import os
3+
import time
4+
import multiprocessing as mp
5+
import signal
6+
import sys
7+
import threading
8+
9+
import ezmsg.core as ez
10+
11+
12+
class CountSettings(ez.Settings):
13+
num_msgs: int = 3
14+
15+
16+
class CountPublisher(ez.Unit):
17+
SETTINGS = CountSettings
18+
OUTPUT = ez.OutputStream(int)
19+
20+
@ez.publisher(OUTPUT)
21+
async def publish(self):
22+
for i in range(self.SETTINGS.num_msgs):
23+
yield self.OUTPUT, i
24+
await asyncio.sleep(0.01)
25+
raise ez.Complete
26+
27+
28+
class CompleteSubscriber(ez.Unit):
29+
SETTINGS = CountSettings
30+
INPUT = ez.InputStream(int)
31+
32+
async def initialize(self) -> None:
33+
self._count = 0
34+
35+
@ez.subscriber(INPUT)
36+
async def on_message(self, _msg: int) -> None:
37+
self._count += 1
38+
if self._count >= self.SETTINGS.num_msgs:
39+
raise ez.Complete
40+
41+
42+
class NormalTerminationSubscriber(ez.Unit):
43+
SETTINGS = CountSettings
44+
INPUT = ez.InputStream(int)
45+
46+
async def initialize(self) -> None:
47+
self._count = 0
48+
49+
@ez.subscriber(INPUT)
50+
async def on_message(self, _msg: int) -> None:
51+
self._count += 1
52+
if self._count >= self.SETTINGS.num_msgs:
53+
raise ez.NormalTermination
54+
55+
56+
class NormalTerminationSubscriberWithThread(ez.Unit):
57+
SETTINGS = CountSettings
58+
INPUT = ez.InputStream(int)
59+
60+
async def initialize(self) -> None:
61+
self._count = 0
62+
self._stop_thread = False
63+
64+
@ez.thread
65+
def background(self) -> None:
66+
while not self._stop_thread:
67+
time.sleep(0.05)
68+
69+
@ez.subscriber(INPUT)
70+
async def on_message(self, _msg: int) -> None:
71+
self._count += 1
72+
if self._count >= self.SETTINGS.num_msgs:
73+
self._stop_thread = True
74+
raise ez.NormalTermination
75+
76+
77+
class InfiniteTask(ez.Unit):
78+
@ez.task
79+
async def run(self) -> None:
80+
while True:
81+
await asyncio.sleep(0.2)
82+
83+
84+
class BaseSystem(ez.Collection):
85+
PUB = CountPublisher()
86+
SUB = CompleteSubscriber()
87+
88+
def configure(self) -> None:
89+
self.PUB.apply_settings(CountSettings())
90+
self.SUB.apply_settings(CountSettings())
91+
92+
def network(self) -> ez.NetworkDefinition:
93+
return ((self.PUB.OUTPUT, self.SUB.INPUT),)
94+
95+
def process_components(self):
96+
return (self.PUB, self.SUB)
97+
98+
99+
class CompleteSystem(BaseSystem):
100+
SUB = CompleteSubscriber()
101+
102+
103+
class NormalTerminationSystem(BaseSystem):
104+
SUB = NormalTerminationSubscriber()
105+
106+
107+
class NormalTerminationThreadSystem(BaseSystem):
108+
SUB = NormalTerminationSubscriberWithThread()
109+
110+
111+
class InfiniteSystem(ez.Collection):
112+
TASK = InfiniteTask()
113+
114+
def process_components(self):
115+
return (self.TASK,)
116+
117+
118+
SYSTEMS = {
119+
"complete": CompleteSystem,
120+
"normalterm": NormalTerminationSystem,
121+
"normalterm_thread": NormalTerminationThreadSystem,
122+
"infinite": InfiniteSystem,
123+
}
124+
125+
126+
def main() -> None:
127+
if os.environ.get("EZMSG_INBAND_SIGINT"):
128+
def _listen() -> None:
129+
for line in sys.stdin:
130+
if line.strip().upper() == "SIGINT":
131+
signal.raise_signal(signal.SIGINT)
132+
133+
threading.Thread(target=_listen, daemon=True).start()
134+
135+
start_method = os.environ.get("EZMSG_MP_START")
136+
if start_method:
137+
mp.set_start_method(start_method, force=True)
138+
case = os.environ.get("EZMSG_SHUTDOWN_EXAMPLE")
139+
if case not in SYSTEMS:
140+
raise SystemExit(
141+
"EZMSG_SHUTDOWN_EXAMPLE must be one of: "
142+
+ ", ".join(sorted(SYSTEMS))
143+
)
144+
system = SYSTEMS[case]()
145+
print("READY", flush=True)
146+
ez.run(SYSTEM=system)
147+
148+
149+
if __name__ == "__main__":
150+
main()

0 commit comments

Comments
 (0)