Skip to content

Commit c00ffa0

Browse files
authored
Add no_message_restart_interval to reconnect when no messages arrive for a while (#182)
1 parent f46a5c2 commit c00ffa0

5 files changed

Lines changed: 99 additions & 5 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ async with stompman.Client(
4141
disconnect_confirmation_timeout=2,
4242
write_retry_attempts=3,
4343
check_server_alive_interval_factor=3,
44+
no_message_restart_interval=datetime.timedelta(hours=1), # None to disable
4445
) as client:
4546
...
4647
```
@@ -149,6 +150,7 @@ stompman takes care of cleaning up resources automatically. When you leave the c
149150
- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds. If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. In normal situation it doesn't need to be handled: tune retry and timeout parameters in `stompman.Client()` to your needs.
150151

151152
- When connection is lost, stompman will attempt to handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.FailedAllWriteAttemptsError` will be raised if connection succeeds but sending a frame or heartbeat lead to losing connection.
153+
- If no messages are received for `no_message_restart_interval` (defaults to 1 hour), stompman will force a reconnect. Set to `None` to disable.
152154
- To implement health checks, use `stompman.Client.is_alive()` — it will return `True` if everything is OK and `False` if server is not responding.
153155
- `stompman` will write log warnings when connection is lost, after successful reconnection or invalid state during ack/nack.
154156

packages/stompman/stompman/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
2+
import time
23
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
34
from contextlib import AsyncExitStack, asynccontextmanager
45
from dataclasses import dataclass, field
6+
from datetime import timedelta
57
from functools import partial
68
from ssl import SSLContext
79
from types import TracebackType
@@ -45,6 +47,8 @@ class Client:
4547
disconnect_confirmation_timeout: int = 2
4648
check_server_alive_interval_factor: int = 3
4749
"""Client will check if server alive `server heartbeat interval` times `interval factor`"""
50+
no_message_restart_interval: timedelta | None = timedelta(hours=1)
51+
"""Force reconnect if no messages received within this interval. None to disable."""
4852

4953
connection_class: type[AbstractConnection] = Connection
5054

@@ -74,6 +78,7 @@ def __post_init__(self) -> None:
7478
read_max_chunk_size=self.read_max_chunk_size,
7579
write_retry_attempts=self.write_retry_attempts,
7680
check_server_alive_interval_factor=self.check_server_alive_interval_factor,
81+
no_message_restart_interval=self.no_message_restart_interval,
7782
ssl=self.ssl,
7883
)
7984

@@ -99,6 +104,7 @@ async def _listen_to_frames(self) -> None:
99104
async for frame in self._connection_manager.read_frames_reconnecting():
100105
match frame:
101106
case MessageFrame():
107+
self._connection_manager._last_message_received_time = time.time()
102108
if subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"]):
103109
task_group.create_task(
104110
subscription._run_handler(frame=frame)

packages/stompman/stompman/connection_manager.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from collections.abc import AsyncGenerator
44
from dataclasses import dataclass, field
5+
from datetime import timedelta
56
from ssl import SSLContext
67
from types import TracebackType
78
from typing import TYPE_CHECKING, Literal, Self
@@ -50,12 +51,15 @@ class ConnectionManager:
5051
read_max_chunk_size: int
5152
write_retry_attempts: int
5253
check_server_alive_interval_factor: int
54+
no_message_restart_interval: timedelta | None
5355

5456
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
5557
_reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock)
5658
_task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup)
5759
_send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
60+
_monitor_no_message_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False)
5861
_reconnection_count: int = field(default=0, init=False)
62+
_last_message_received_time: float = field(init=False, default_factory=time.time)
5963

6064
async def __aenter__(self) -> Self:
6165
await self._task_group.__aenter__()
@@ -67,7 +71,11 @@ async def __aexit__(
6771
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
6872
) -> None:
6973
self._send_heartbeat_task.cancel()
70-
await asyncio.wait([self._send_heartbeat_task])
74+
tasks = [self._send_heartbeat_task]
75+
if self._monitor_no_message_task is not None:
76+
self._monitor_no_message_task.cancel()
77+
tasks.append(self._monitor_no_message_task)
78+
await asyncio.wait(tasks)
7179
await self._task_group.__aexit__(exc_type, exc_value, traceback)
7280

7381
if not self._active_connection_state:
@@ -78,18 +86,44 @@ async def __aexit__(
7886
return
7987
await self._active_connection_state.connection.close()
8088

81-
def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
89+
def _restart_background_tasks(self, server_heartbeat: Heartbeat) -> None:
8290
self._send_heartbeat_task.cancel()
8391
self._send_heartbeat_task = self._task_group.create_task(
8492
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
8593
)
8694

95+
def _restart_no_message_monitor(self) -> None:
96+
if self._monitor_no_message_task is not None:
97+
self._monitor_no_message_task.cancel()
98+
if self.no_message_restart_interval is not None:
99+
self._monitor_no_message_task = self._task_group.create_task(
100+
self._monitor_no_message_timeout(self.no_message_restart_interval)
101+
)
102+
87103
async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None:
88104
send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
89105
while True:
90106
await self.write_heartbeat_reconnecting()
91107
await asyncio.sleep(send_heartbeat_interval_seconds)
92108

109+
async def _monitor_no_message_timeout(self, interval: timedelta) -> None:
110+
interval_seconds = interval.total_seconds()
111+
while True:
112+
elapsed = time.time() - self._last_message_received_time
113+
if (remaining := interval_seconds - elapsed) > 0:
114+
await asyncio.sleep(remaining)
115+
else:
116+
if connection_state := self._active_connection_state:
117+
LOGGER.warning(
118+
"no messages received for %s seconds, forcing reconnect",
119+
interval_seconds,
120+
)
121+
self._clear_active_connection_state(
122+
ConnectionLostError(reason="no messages received within timeout")
123+
)
124+
await connection_state.connection.close()
125+
await asyncio.sleep(interval_seconds)
126+
93127
async def _create_connection_to_one_server(
94128
self, server: ConnectionParameters
95129
) -> tuple[AbstractConnection, ConnectionParameters] | None:
@@ -121,7 +155,7 @@ async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionI
121155
lifespan = self.lifespan_factory(
122156
connection=connection,
123157
connection_parameters=connection_parameters,
124-
set_heartbeat_interval=self._restart_heartbeat_tasks,
158+
set_heartbeat_interval=self._restart_background_tasks,
125159
)
126160

127161
try:
@@ -152,6 +186,8 @@ async def _get_active_connection_state(self, *, is_initial_call: bool = False) -
152186

153187
if isinstance(connection_result, ActiveConnectionState):
154188
self._active_connection_state = connection_result
189+
self._last_message_received_time = time.time()
190+
self._restart_no_message_monitor()
155191
if not is_initial_call:
156192
LOGGER.warning(
157193
"reconnected after connection failure. connection_parameters: %s",

packages/stompman/test_stompman/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Callable
33
from dataclasses import dataclass, field
4+
from datetime import timedelta
45
from ssl import SSLContext
56
from typing import Any, Literal, Self, TypeVar
67

@@ -53,6 +54,7 @@ class EnrichedClient(stompman.Client):
5354
servers: list[stompman.ConnectionParameters] = field(
5455
default_factory=lambda: [stompman.ConnectionParameters("localhost", 12345, "login", "passcode")], kw_only=False
5556
)
57+
no_message_restart_interval: timedelta | None = None
5658

5759

5860
@dataclass(frozen=True, kw_only=True, slots=True)
@@ -80,6 +82,7 @@ class EnrichedConnectionManager(ConnectionManager):
8082
write_retry_attempts: int = 3
8183
ssl: Literal[True] | SSLContext | None = None
8284
check_server_alive_interval_factor: int = 3
85+
no_message_restart_interval: timedelta | None = None
8386

8487

8588
DataclassType = TypeVar("DataclassType")

packages/stompman/test_stompman/test_connection_manager.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2+
import time
23
from collections.abc import AsyncGenerator, AsyncIterable
4+
from datetime import timedelta
35
from ssl import SSLContext
46
from typing import Literal, Self
57
from unittest import mock
@@ -147,12 +149,12 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None:
147149
mock.call(
148150
connection=BaseMockConnection(),
149151
connection_parameters=manager.servers[0],
150-
set_heartbeat_interval=manager._restart_heartbeat_tasks,
152+
set_heartbeat_interval=manager._restart_background_tasks,
151153
),
152154
mock.call(
153155
connection=BaseMockConnection(),
154156
connection_parameters=manager.servers[0],
155-
set_heartbeat_interval=manager._restart_heartbeat_tasks,
157+
set_heartbeat_interval=manager._restart_background_tasks,
156158
),
157159
]
158160

@@ -368,3 +370,48 @@ class MockConnection(BaseMockConnection):
368370
async def test_maybe_write_frame_ok() -> None:
369371
async with EnrichedConnectionManager(connection_class=BaseMockConnection) as manager:
370372
assert await manager.maybe_write_frame(build_dataclass(ConnectFrame))
373+
374+
375+
async def test_no_message_restart_triggers_reconnect(monkeypatch: pytest.MonkeyPatch) -> None:
376+
frozen_time = [time.time()]
377+
monkeypatch.setattr("time.time", lambda: frozen_time[0])
378+
379+
async with EnrichedConnectionManager(
380+
connection_class=BaseMockConnection, no_message_restart_interval=timedelta(seconds=10)
381+
) as manager:
382+
frozen_time[0] += 11
383+
for _ in range(10):
384+
await asyncio.sleep(0)
385+
assert manager._reconnection_count >= 1
386+
old_monitor_task = manager._monitor_no_message_task
387+
await manager._get_active_connection_state()
388+
assert manager._monitor_no_message_task is not old_monitor_task
389+
390+
391+
async def test_no_message_restart_does_not_trigger_when_messages_flow(monkeypatch: pytest.MonkeyPatch) -> None:
392+
frozen_time = [time.time()]
393+
monkeypatch.setattr("time.time", lambda: frozen_time[0])
394+
395+
async with EnrichedConnectionManager(
396+
connection_class=BaseMockConnection, no_message_restart_interval=timedelta(seconds=10)
397+
) as manager:
398+
initial_reconnection_count = manager._reconnection_count
399+
frozen_time[0] += 5
400+
manager._last_message_received_time = frozen_time[0]
401+
for _ in range(10):
402+
await asyncio.sleep(0)
403+
assert manager._reconnection_count == initial_reconnection_count
404+
405+
406+
async def test_no_message_restart_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
407+
frozen_time = [time.time()]
408+
monkeypatch.setattr("time.time", lambda: frozen_time[0])
409+
410+
async with EnrichedConnectionManager(
411+
connection_class=BaseMockConnection, no_message_restart_interval=None
412+
) as manager:
413+
assert manager._monitor_no_message_task is None
414+
frozen_time[0] += 99999
415+
for _ in range(10):
416+
await asyncio.sleep(0)
417+
assert manager._reconnection_count == 0

0 commit comments

Comments
 (0)