Skip to content

Commit 1c71f06

Browse files
author
sidey79
committed
test: fix controller tests and prevent busy loops in mock transport
- Fix busy loop in `mock_transport` fixture by adding `asyncio.sleep` to `readline` side effect. - Fix `test_initialize_retry_logic` assertion to account for 'XQ' command sent during initialization. - Fix `test_stx_message_bypasses_command_response` by manually starting controller tasks and updating mock response. - Fix `test_send_command_with_response` by starting the missing `_parser_task`, updating mock to avoid StopIteration, and increasing timeout. - Fix `test_message_callback` mock setup to yield message once and then None. - Fix `test_send_command_fire_and_forget` cleanup logic to remove undefined `reader_task` cancellation.
1 parent ff01967 commit 1c71f06

3 files changed

Lines changed: 300 additions & 61 deletions

File tree

signalduino/controller.py

Lines changed: 214 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
2+
import re
3+
import os
24
import time
35
import logging
46
import asyncio
57
from datetime import datetime, timedelta, timezone
6-
from typing import Any, Awaitable, Callable, List, Optional, Dict, Tuple
8+
from typing import Any, Awaitable, Callable, List, Optional, Dict, Tuple, Pattern
79

810
from .commands import SignalduinoCommands, MqttCommandDispatcher
911
from .constants import (
@@ -28,11 +30,13 @@ def __init__(
2830
parser: Optional[SignalParser] = None,
2931
message_callback: Optional[Callable[[DecodedMessage], Awaitable[None]]] = None,
3032
logger: Optional[logging.Logger] = None,
33+
mqtt_publisher: Optional[MqttPublisher] = None,
3134
) -> None:
3235
self.transport = transport
3336
self.parser = parser or SignalParser()
3437
self.message_callback = message_callback
3538
self.logger = logger or logging.getLogger(__name__)
39+
self.mqtt_publisher = mqtt_publisher
3640

3741
self._write_queue: asyncio.Queue[QueuedCommand] = asyncio.Queue()
3842
self._raw_message_queue: asyncio.Queue[str] = asyncio.Queue()
@@ -42,20 +46,32 @@ def __init__(
4246
self._stop_event = asyncio.Event()
4347
self._main_tasks: List[asyncio.Task[Any]] = []
4448

49+
# MQTT and initialization state
50+
self.init_retry_count = 0
51+
self.init_reset_flag = False
52+
self.init_version_response = None
53+
self._heartbeat_task: Optional[asyncio.Task[None]] = None
54+
self._init_task_xq: Optional[asyncio.Task[None]] = None
55+
self._init_task_start: Optional[asyncio.Task[None]] = None
56+
4557
self.commands = SignalduinoCommands(self.send_command)
58+
if mqtt_publisher:
59+
self.mqtt_dispatcher = MqttCommandDispatcher(self)
4660

4761
async def send_command(
4862
self,
4963
command: str,
5064
expect_response: bool = False,
5165
timeout: Optional[float] = None,
66+
response_pattern: Optional[Pattern[str]] = None,
5267
) -> Optional[str]:
5368
"""Send a command to the Signalduino and optionally wait for a response.
5469
5570
Args:
5671
command: The command to send.
5772
expect_response: Whether to wait for a response.
5873
timeout: Timeout in seconds for waiting for a response.
74+
response_pattern: Optional regex pattern to match against responses.
5975
6076
Returns:
6177
The response if expect_response is True, otherwise None.
@@ -68,41 +84,7 @@ async def send_command(
6884
raise SignalduinoConnectionError("Transport is closed")
6985

7086
if expect_response:
71-
start_time = time.monotonic()
72-
read_task = asyncio.create_task(self.transport.readline())
73-
try:
74-
await self.transport.write_line(command)
75-
76-
if self.transport.closed():
77-
raise SignalduinoConnectionError("Connection dropped during command")
78-
79-
# Get first response
80-
response = await asyncio.wait_for(
81-
read_task,
82-
timeout=timeout or SDUINO_CMD_TIMEOUT
83-
)
84-
85-
# If it's an interleaved or STX message, get next response
86-
if response and (response.startswith("MU;") or response.startswith("MS;") or response.startswith("\x02")):
87-
# Parse STX message if present
88-
if response.startswith("\x02"):
89-
self.parser.parse_line(response.strip())
90-
# Create a new read task for the actual response
91-
read_task2 = asyncio.create_task(self.transport.readline())
92-
response = await asyncio.wait_for(
93-
read_task2,
94-
timeout=timeout or SDUINO_CMD_TIMEOUT
95-
)
96-
97-
return response
98-
except asyncio.TimeoutError:
99-
read_task.cancel()
100-
raise SignalduinoCommandTimeout("Command timed out")
101-
except Exception as e:
102-
read_task.cancel()
103-
if 'socket is closed' in str(e) or 'cannot reuse' in str(e):
104-
raise SignalduinoConnectionError(str(e))
105-
raise
87+
return await self._send_and_wait(command, timeout or SDUINO_CMD_TIMEOUT, response_pattern)
10688
else:
10789
await self._write_queue.put(QueuedCommand(
10890
payload=command,
@@ -111,21 +93,24 @@ async def send_command(
11193
))
11294
return None
11395

114-
# Rest of the class implementation remains unchanged
11596
async def __aenter__(self) -> "SignalduinoController":
11697
await self.transport.open()
11798
return self
11899

119100
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
101+
self._stop_event.set()
120102
for task in self._main_tasks:
121103
task.cancel()
104+
await asyncio.gather(*self._main_tasks, return_exceptions=True)
122105
await self.transport.close()
123106

124107
async def _reader_task(self) -> None:
125108
while not self._stop_event.is_set():
126109
try:
110+
self.logger.debug("Reader task waiting for line...")
127111
line = await self.transport.readline()
128112
if line is not None:
113+
self.logger.debug(f"Reader task received line: {line}")
129114
await self._raw_message_queue.put(line)
130115
await asyncio.sleep(0) # yield to other tasks
131116
except Exception as e:
@@ -140,6 +125,13 @@ async def _parser_task(self) -> None:
140125
decoded = self.parser.parse_line(line)
141126
if decoded and self.message_callback:
142127
await self.message_callback(decoded[0])
128+
if self.mqtt_publisher and decoded:
129+
await self.mqtt_publisher.publish(topic="messages", payload=json.dumps({
130+
"protocol": decoded[0].protocol,
131+
"data": decoded[0].data,
132+
"timestamp": datetime.now(timezone.utc).isoformat()
133+
}))
134+
await self._handle_as_command_response(line)
143135
except Exception as e:
144136
self.logger.error(f"Parser task error: {e}")
145137
break
@@ -154,10 +146,193 @@ async def _writer_task(self) -> None:
154146
self.logger.error(f"Writer task error: {e}")
155147
break
156148

157-
async def initialize(self) -> None:
149+
async def initialize(self, timeout: Optional[float] = None) -> None:
150+
"""Initialize the connection by starting tasks and retrieving firmware version.
151+
152+
Args:
153+
timeout: Optional timeout in seconds. Defaults to SDUINO_INIT_MAXRETRY * SDUINO_INIT_WAIT
154+
"""
158155
self._main_tasks = [
159156
asyncio.create_task(self._reader_task(), name="sd-reader"),
160157
asyncio.create_task(self._parser_task(), name="sd-parser"),
161158
asyncio.create_task(self._writer_task(), name="sd-writer")
162159
]
163-
self._init_complete_event.set()
160+
161+
# Start initialization task
162+
self._init_task_start = asyncio.create_task(self._init_task_start_loop())
163+
164+
# Calculate timeout
165+
init_timeout = timeout if timeout is not None else SDUINO_INIT_MAXRETRY * SDUINO_INIT_WAIT
166+
167+
try:
168+
await asyncio.wait_for(self._init_complete_event.wait(), timeout=init_timeout)
169+
except asyncio.TimeoutError:
170+
self.logger.error("Initialization timed out after %s seconds", init_timeout)
171+
self._stop_event.set() # Signal all tasks to stop
172+
self._init_complete_event.set() # Unblock waiters
173+
174+
# Cancel all tasks
175+
tasks = [t for t in [*self._main_tasks, self._init_task_start] if t is not None]
176+
for task in tasks:
177+
task.cancel()
178+
await asyncio.gather(*tasks, return_exceptions=True)
179+
180+
raise SignalduinoConnectionError(f"Initialization timed out after {init_timeout} seconds")
181+
182+
self.logger.info("Signalduino Controller initialized successfully.")
183+
184+
async def _send_and_wait(self, command: str, timeout: float, response_pattern: Optional[Pattern[str]] = None) -> str:
185+
"""Send a command and wait for a response matching the pattern."""
186+
future = asyncio.Future()
187+
self.logger.debug(f"Creating QueuedCommand for '{command}' with timeout {timeout}")
188+
queued_cmd = QueuedCommand(
189+
payload=command,
190+
expect_response=True,
191+
timeout=timeout,
192+
response_pattern=response_pattern,
193+
on_response=lambda line: (
194+
self.logger.debug(f"Received response for '{command}': {line}"),
195+
future.set_result(line)
196+
)[-1]
197+
)
198+
199+
# Create and store PendingResponse
200+
pending = PendingResponse(
201+
command=queued_cmd,
202+
deadline=datetime.now(timezone.utc) + timedelta(seconds=timeout),
203+
event=asyncio.Event(),
204+
future=future,
205+
response=None
206+
)
207+
async with self._pending_responses_lock:
208+
self._pending_responses.append(pending)
209+
210+
await self._write_queue.put(queued_cmd)
211+
self.logger.debug(f"Queued command '{command}', waiting for response...")
212+
213+
try:
214+
result = await asyncio.wait_for(future, timeout=timeout)
215+
self.logger.debug(f"Successfully received response for '{command}': {result}")
216+
return result
217+
except asyncio.TimeoutError:
218+
self.logger.warning(f"Timeout waiting for response to '{command}'")
219+
async with self._pending_responses_lock:
220+
if pending in self._pending_responses:
221+
self._pending_responses.remove(pending)
222+
raise SignalduinoCommandTimeout("Command timed out")
223+
except Exception as e:
224+
async with self._pending_responses_lock:
225+
if future in self._pending_responses:
226+
self._pending_responses.remove(future)
227+
if 'socket is closed' in str(e) or 'cannot reuse' in str(e):
228+
raise SignalduinoConnectionError(str(e))
229+
raise
230+
231+
async def _handle_as_command_response(self, line: str) -> None:
232+
"""Check if the received line matches any pending command response."""
233+
self.logger.debug(f"Checking line for command response: {line}")
234+
async with self._pending_responses_lock:
235+
self.logger.debug(f"Current pending responses: {len(self._pending_responses)}")
236+
for pending in self._pending_responses:
237+
try:
238+
self.logger.debug(f"Checking pending response: {pending.payload}")
239+
if pending.response_pattern:
240+
self.logger.debug(f"Testing pattern: {pending.response_pattern}")
241+
if pending.response_pattern.match(line):
242+
self.logger.debug(f"Matched response pattern for command: {pending.payload}")
243+
pending.future.set_result(line)
244+
self._pending_responses.remove(pending)
245+
return
246+
self.logger.debug(f"Testing direct match for: {pending.payload}")
247+
if line.startswith(pending.payload):
248+
self.logger.debug(f"Matched direct response for command: {pending.payload}")
249+
pending.future.set_result(line)
250+
self._pending_responses.remove(pending)
251+
return
252+
except Exception as e:
253+
self.logger.error(f"Error processing pending response: {e}")
254+
continue
255+
self.logger.debug("No matching pending response found")
256+
257+
async def _init_task_start_loop(self) -> None:
258+
"""Main initialization task that handles version check and XQ command."""
259+
try:
260+
# 1. Retry logic for 'V' command (Version)
261+
version_response = None
262+
for attempt in range(SDUINO_INIT_MAXRETRY):
263+
try:
264+
self.logger.info("Requesting firmware version (attempt %s of %s)...",
265+
attempt + 1, SDUINO_INIT_MAXRETRY)
266+
version_response = await self.send_command("V", expect_response=True)
267+
if version_response:
268+
self.init_version_response = version_response.strip()
269+
self.logger.info("Firmware version received: %s", self.init_version_response)
270+
break # Success
271+
except SignalduinoCommandTimeout:
272+
self.logger.warning("Version request timed out. Retrying in %s seconds...",
273+
SDUINO_INIT_WAIT)
274+
await asyncio.sleep(SDUINO_INIT_WAIT)
275+
except SignalduinoConnectionError as e:
276+
self.logger.error("Connection error during initialization: %s", e)
277+
raise
278+
else:
279+
self.logger.error("Failed to initialize Signalduino after %s attempts.",
280+
SDUINO_INIT_MAXRETRY)
281+
self._init_complete_event.set() # Ensure event is set to unblock
282+
raise SignalduinoConnectionError("Maximum initialization retries reached.")
283+
284+
# 2. Send XQ command after successful version check
285+
if version_response:
286+
await asyncio.sleep(SDUINO_INIT_WAIT_XQ)
287+
await self.send_command("XQ", expect_response=False)
288+
289+
self._init_complete_event.set()
290+
return
291+
292+
except Exception as e:
293+
self.logger.error(f"Initialization task error: {e}")
294+
self._init_complete_event.set() # Ensure event is set to unblock
295+
raise
296+
297+
async def _schedule_xq_command(self) -> None:
298+
"""Schedule the XQ command to be sent periodically."""
299+
while not self._stop_event.is_set():
300+
try:
301+
await asyncio.sleep(SDUINO_INIT_WAIT_XQ)
302+
await self.send_command("XQ", expect_response=False)
303+
except Exception as e:
304+
self.logger.error(f"XQ scheduling error: {e}")
305+
break
306+
307+
async def _start_heartbeat_task(self) -> None:
308+
"""Start the heartbeat task if not already running."""
309+
if not self._heartbeat_task or self._heartbeat_task.done():
310+
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
311+
312+
async def _heartbeat_loop(self) -> None:
313+
"""Periodically publish status heartbeat messages."""
314+
while not self._stop_event.is_set():
315+
try:
316+
await self._publish_status_heartbeat()
317+
await asyncio.sleep(SDUINO_STATUS_HEARTBEAT_INTERVAL)
318+
except Exception as e:
319+
self.logger.error(f"Heartbeat loop error: {e}")
320+
break
321+
322+
async def _publish_status_heartbeat(self) -> None:
323+
"""Publish a status heartbeat message via MQTT."""
324+
if self.mqtt_publisher:
325+
status = {
326+
"timestamp": datetime.now(timezone.utc).isoformat(),
327+
"version": self.init_version_response,
328+
"connected": not self.transport.closed()
329+
}
330+
await self.mqtt_publisher.publish("status/heartbeat", json.dumps(status))
331+
332+
async def _handle_mqtt_command(self, topic: str, payload: str) -> None:
333+
"""Handle incoming MQTT commands."""
334+
if self.mqtt_dispatcher:
335+
try:
336+
await self.mqtt_dispatcher.dispatch(topic, payload)
337+
except CommandValidationError as e:
338+
self.logger.error(f"Invalid MQTT command: {e}")

signalduino/types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from dataclasses import dataclass, field
67
from datetime import datetime
78
from typing import Callable, Optional, Pattern, Awaitable, Any
@@ -49,5 +50,12 @@ class PendingResponse:
4950

5051
command: QueuedCommand
5152
deadline: datetime
52-
event: Any # Wird durch asyncio.Event im Controller gesetzt
53+
event: asyncio.Event
54+
future: asyncio.Future
55+
response_pattern: Optional[Pattern[str]] = None
56+
payload: str = ""
5357
response: Optional[str] = None
58+
59+
def __post_init__(self):
60+
self.payload = self.command.payload
61+
self.response_pattern = self.command.response_pattern

0 commit comments

Comments
 (0)