Skip to content

Commit 137855d

Browse files
committed
fix: proper return codes
1 parent 7f6ca7b commit 137855d

5 files changed

Lines changed: 202 additions & 20 deletions

File tree

src/ezmsg/core/backend.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import enum
55
import logging
66
import os
7+
import signal
78
from threading import BrokenBarrierError
89
from multiprocessing import Event, Barrier
910
from multiprocessing.synchronize import Event as EventType
@@ -23,6 +24,7 @@
2324
from .backendprocess import (
2425
BackendProcess,
2526
DefaultBackendProcess,
27+
ShutdownSummary,
2628
new_threaded_event_loop,
2729
)
2830

@@ -170,6 +172,7 @@ class GraphRunner:
170172
_graph_context: GraphContext | None
171173
_loop: asyncio.AbstractEventLoop | None
172174
_loop_cm: object | None
175+
_loop_shutdown_summary: ShutdownSummary | None
173176
_main_process: BackendProcess | None
174177
_spawned_processes: list[BackendProcess]
175178
_start_participant: bool
@@ -208,6 +211,7 @@ def __init__(
208211
self._graph_context = None
209212
self._loop = None
210213
self._loop_cm = None
214+
self._loop_shutdown_summary = None
211215
self._main_process = None
212216
self._spawned_processes = []
213217
self._start_participant = False
@@ -320,7 +324,10 @@ def _initialize(self, force_single_process: bool, wait_for_ready: bool) -> bool:
320324
if self._execution_context is None:
321325
return False
322326

323-
self._loop_cm = new_threaded_event_loop()
327+
self._loop_shutdown_summary = ShutdownSummary()
328+
self._loop_cm = new_threaded_event_loop(
329+
shutdown_summary=self._loop_shutdown_summary
330+
)
324331
self._loop = self._loop_cm.__enter__()
325332

326333
try:
@@ -389,12 +396,15 @@ def _run_main_process(self) -> None:
389396
self._main_process = self._execution_context.processes[0]
390397
self._start_processes(self._execution_context.processes[1:])
391398

399+
interrupts = 0
400+
forced_sigint = False
392401
try:
393402
self._main_process.process(self._loop)
394403
self._join_spawned_processes()
395404
logger.info("All processes exited normally")
396405

397406
except KeyboardInterrupt:
407+
interrupts += 1
398408
logger.info(
399409
"Attempting graceful shutdown, interrupt again to force quit..."
400410
)
@@ -404,17 +414,71 @@ def _run_main_process(self) -> None:
404414
self._join_spawned_processes()
405415

406416
except KeyboardInterrupt:
417+
interrupts += 1
418+
forced_sigint = True
407419
logger.warning("Interrupt intercepted, force quitting")
408420
self._execution_context.start_barrier.abort()
409421
self._execution_context.stop_barrier.abort()
410422
for proc in self._spawned_processes:
411423
proc.terminate()
412424

413425
finally:
414-
self._join_spawned_processes()
415-
self._cleanup()
426+
while True:
427+
try:
428+
self._join_spawned_processes()
429+
self._cleanup()
430+
break
431+
except KeyboardInterrupt:
432+
interrupts += 1
433+
if interrupts >= 2:
434+
forced_sigint = True
435+
logger.warning("Interrupt intercepted, force quitting")
436+
if self._execution_context is not None:
437+
self._execution_context.start_barrier.abort()
438+
self._execution_context.stop_barrier.abort()
439+
for proc in self._spawned_processes:
440+
proc.terminate()
441+
self._cleanup()
442+
break
443+
logger.info(
444+
"Interrupt received during cleanup; attempting graceful shutdown..."
445+
)
446+
if self._execution_context is not None:
447+
self._execution_context.term_ev.set()
416448
self._started = False
417449
self._stopped = True
450+
if interrupts and not forced_sigint and self._shutdown_was_unclean():
451+
forced_sigint = True
452+
if forced_sigint:
453+
self._exit_with_sigint()
454+
455+
def _shutdown_was_unclean(self) -> bool:
    """
    Report whether the most recent shutdown left anything behind.

    Two independent signals are consulted: the ``_shutdown_errors`` flag on
    the main backend process (treated as ``False`` when the process object
    does not carry the attribute), and the ``unclean`` property of the
    event-loop shutdown summary.

    :return: True if either the main process or the loop summary reports
        an unclean shutdown, False otherwise.
    :rtype: bool
    """
    process = self._main_process
    if process is None:
        task_errors = False
    else:
        # getattr with a default: not every BackendProcess implementation
        # is guaranteed to define the flag.
        task_errors = bool(getattr(process, "_shutdown_errors", False))

    loop_summary = self._loop_shutdown_summary
    if loop_summary is None:
        loop_dirty = False
    else:
        loop_dirty = bool(loop_summary.unclean)

    return task_errors or loop_dirty
463+
464+
def _exit_with_sigint(self) -> None:
    """
    Terminate the process the way an unhandled Ctrl-C would.

    The default SIGINT disposition is installed and SIGINT is raised
    against the current process. If that cannot be done (for example,
    ``signal.signal`` raises when not called from the main thread), or if
    the signal does not end the process, ``SystemExit`` is raised instead
    with 130 on POSIX (conventionally 128 + SIGINT) or ``0xC000013A`` on
    Windows (STATUS_CONTROL_C_EXIT).

    :raises SystemExit: Always, unless the raised signal kills the process
        first.
    """
    exit_code = 0xC000013A if os.name == "nt" else 130
    saved_handler = None
    try:
        saved_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        signal.raise_signal(signal.SIGINT)
    except Exception:
        raise SystemExit(exit_code)
    finally:
        # Only reached when the process survived; put the caller's handler
        # back on a best-effort basis so later Ctrl-C handling is unchanged.
        if saved_handler is not None:
            try:
                signal.signal(signal.SIGINT, saved_handler)
            except Exception:
                pass

    raise SystemExit(exit_code)
418482

419483
def _cleanup(self) -> None:
420484
if self._cleanup_done:

src/ezmsg/core/backendprocess.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import weakref
1010

1111
from abc import abstractmethod
12+
from dataclasses import dataclass
1213
from collections import defaultdict
1314
from collections.abc import Callable, Coroutine, Generator, Sequence
1415
from functools import wraps, partial
@@ -39,7 +40,28 @@ def _strict_shutdown_enabled() -> bool:
3940
return value.lower() in ("1", "true", "yes", "on")
4041

4142

43+
@dataclass
class ShutdownSummary:
    """
    Mutable record of what happened while the threaded event loop shut down.

    Instances are created with every counter at zero and filled in by the
    code that tears the loop down.
    """

    # Number of tasks that were cancelled during shutdown.
    cancelled_tasks: int = 0
    # Number of executor jobs still outstanding at shutdown.
    executor_active: int = 0
    # Number of errors that were suppressed during shutdown.
    suppressed_errors: int = 0
    # Whether shutdown was cut short by a forced interrupt.
    forced_interrupt: bool = False

    @property
    def unclean(self) -> bool:
        """
        Whether shutdown left work or errors behind.

        ``cancelled_tasks`` is deliberately not part of this check; only
        outstanding executor work, suppressed errors, or a forced interrupt
        make a shutdown unclean.

        :return: True if any of those three indicators is set.
        :rtype: bool
        """
        indicators = (
            self.executor_active,
            self.suppressed_errors,
            self.forced_interrupt,
        )
        return any(indicators)
57+
58+
4259
class _DaemonThreadPoolExecutor(ThreadPoolExecutor):
60+
def __init__(self, *args, **kwargs) -> None:
    """
    Build the executor and initialise its outstanding-work bookkeeping.

    All arguments are forwarded unchanged to ``ThreadPoolExecutor``.
    """
    super().__init__(*args, **kwargs)
    # The counter is touched from submitting threads and from future
    # done-callbacks, so every access goes through this lock.
    self._active_lock = threading.Lock()
    self._active_count = 0
64+
4365
def _adjust_thread_count(self) -> None:
4466
if self._broken:
4567
return
@@ -61,6 +83,22 @@ def _adjust_thread_count(self) -> None:
6183
thread.start()
6284
self._threads.add(thread)
6385

86+
def submit(self, fn, /, *args, **kwargs):
    """
    Schedule ``fn(*args, **kwargs)`` and track it as outstanding work.

    The outstanding-work counter is incremented once the base executor has
    accepted the call and decremented by a done-callback on the returned
    future.

    :return: The future produced by ``ThreadPoolExecutor.submit``.
    """
    future = super().submit(fn, *args, **kwargs)

    def _release(_finished_future):
        with self._active_lock:
            self._active_count -= 1

    with self._active_lock:
        self._active_count += 1

    # Registered only after the increment: if the future has already
    # finished, the callback fires immediately and must see the increment.
    future.add_done_callback(_release)
    return future
97+
98+
def active_count(self) -> int:
    """
    Return the current value of the outstanding-work counter.

    :return: Number of submitted calls whose futures have not yet fired
        their done-callback.
    :rtype: int
    """
    with self._active_lock:
        snapshot = self._active_count
    return snapshot
101+
64102

65103
class Complete(Exception):
66104
"""
@@ -178,11 +216,13 @@ class DefaultBackendProcess(BackendProcess):
178216
"""
179217

180218
pubs: dict[str, Publisher]
219+
_shutdown_errors: bool
181220

182221
def process(self, loop: asyncio.AbstractEventLoop) -> None:
183222
main_func = None
184223
context = GraphContext(self.graph_address)
185224
coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict()
225+
self._shutdown_errors = False
186226

187227
try:
188228
self.pubs = dict()
@@ -445,6 +485,8 @@ async def wrapped_task(msg: Any = None) -> None:
445485
except Exception:
446486
logger.error(f"Exception in Task: {task_address}")
447487
logger.error(traceback.format_exc())
488+
if self.term_ev.is_set():
489+
self._shutdown_errors = True
448490
if strict_shutdown:
449491
raise
450492

@@ -522,6 +564,7 @@ def run_loop(loop: asyncio.AbstractEventLoop):
522564
@contextmanager
523565
def new_threaded_event_loop(
524566
ev: threading.Event | None = None,
567+
shutdown_summary: ShutdownSummary | None = None,
525568
) -> Generator[asyncio.AbstractEventLoop, None, None]:
526569
"""
527570
Create a new asyncio event loop running in a separate thread.
@@ -531,6 +574,8 @@ def new_threaded_event_loop(
531574
532575
:param ev: Optional event to signal when the loop is ready.
533576
:type ev: threading.Event | None
577+
:param shutdown_summary: Optional shutdown summary object to populate.
578+
:type shutdown_summary: ShutdownSummary | None
534579
:return: Context manager yielding the event loop.
535580
:rtype: Generator[asyncio.AbstractEventLoop, None, None]
536581
"""
@@ -539,10 +584,10 @@ def new_threaded_event_loop(
539584
shutdown_suppress = threading.Event()
540585
suppressed_shutdown_errors = {"count": 0}
541586
suppressed_lock = threading.Lock()
587+
executor = None
542588
if not strict_shutdown:
543-
loop.set_default_executor(
544-
_DaemonThreadPoolExecutor(thread_name_prefix="EZMSG")
545-
)
589+
executor = _DaemonThreadPoolExecutor(thread_name_prefix="EZMSG")
590+
loop.set_default_executor(executor)
546591
def _loop_exception_handler(
547592
loop_obj: asyncio.AbstractEventLoop, context: dict
548593
) -> None:
@@ -607,6 +652,14 @@ async def _cancel_remaining() -> int:
607652
cancelled_count,
608653
)
609654

655+
if shutdown_summary is not None:
656+
shutdown_summary.cancelled_tasks = cancelled_count
657+
shutdown_summary.executor_active = (
658+
executor.active_count() if executor is not None else 0
659+
)
660+
shutdown_summary.suppressed_errors = suppressed_count
661+
shutdown_summary.forced_interrupt = forced_interrupt
662+
610663
loop.call_soon_threadsafe(loop.stop)
611664
thread.join()
612665
loop.close()

tests/clean_shutdown_examples_runner.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,29 @@ def _listen() -> None:
142142
+ ", ".join(sorted(SYSTEMS))
143143
)
144144
system = SYSTEMS[case]()
145-
print("READY", flush=True)
146-
ez.run(SYSTEM=system)
145+
runner = ez.GraphRunner(SYSTEM=system)
146+
ready_emitted = threading.Event()
147+
done = threading.Event()
148+
149+
def _emit_ready() -> None:
150+
if not ready_emitted.is_set():
151+
print("READY", flush=True)
152+
ready_emitted.set()
153+
154+
def _watch_ready() -> None:
155+
while not done.is_set():
156+
if runner.running:
157+
_emit_ready()
158+
return
159+
time.sleep(0.01)
160+
_emit_ready()
161+
162+
threading.Thread(target=_watch_ready, daemon=True).start()
163+
try:
164+
runner.run_blocking()
165+
finally:
166+
done.set()
167+
_emit_ready()
147168

148169

149170
if __name__ == "__main__":

tests/shutdown_runner.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import socket
44
import threading
5+
import time
56
import signal
67
import sys
78

@@ -73,8 +74,28 @@ def _listen() -> None:
7374
"EZMSG_SHUTDOWN_TEST must be one of: " + ", ".join(sorted(UNITS))
7475
)
7576
runner = ez.GraphRunner(SYSTEM=UNITS[target]())
76-
print("READY", flush=True)
77-
runner.run_blocking()
77+
ready_emitted = threading.Event()
78+
done = threading.Event()
79+
80+
def _emit_ready() -> None:
81+
if not ready_emitted.is_set():
82+
print("READY", flush=True)
83+
ready_emitted.set()
84+
85+
def _watch_ready() -> None:
86+
while not done.is_set():
87+
if runner.running:
88+
_emit_ready()
89+
return
90+
time.sleep(0.01)
91+
_emit_ready()
92+
93+
threading.Thread(target=_watch_ready, daemon=True).start()
94+
try:
95+
runner.run_blocking()
96+
finally:
97+
done.set()
98+
_emit_ready()
7899

79100

80101
if __name__ == "__main__":

0 commit comments

Comments
 (0)