-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathtest_roborock_session.py
More file actions
565 lines (417 loc) · 18 KB
/
test_roborock_session.py
File metadata and controls
565 lines (417 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
"""Tests for the MQTT session module."""
import asyncio
import copy
import datetime
from collections.abc import Callable, Generator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import aiomqtt
import pytest
from roborock.diagnostics import Diagnostics
from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session
from roborock.mqtt.session import MqttSessionException, MqttSessionUnauthorized
from tests import mqtt_packet
from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber
# Fixture modules loaded for every test here: logging helpers plus the
# paho-mqtt and aiomqtt test fixtures referenced by the tests below.
pytest_plugins = [
    "tests.fixtures.logging_fixtures",
    "tests.fixtures.pahomqtt_fixtures",
    "tests.fixtures.aiomqtt_fixtures",
]
@pytest.fixture(autouse=True)
def mqtt_server_fixture(
    mock_paho_mqtt_create_connection: None,
    mock_paho_mqtt_select: None,
) -> None:
    """Prepare the fake MQTT server for every test in this module.

    The function has no body of its own: depending on the paho-mqtt
    connection and select fixtures is what does the work, and marking this
    fixture autouse applies them everywhere.
    """
@pytest.fixture(autouse=True)
def auto_mock_aiomqtt_client(
    mock_aiomqtt_client: None,
) -> None:
    """Apply the ``mock_aiomqtt_client`` fixture to every test in this module.

    Requesting the fixture is all that is needed, so there is no body.
    """
@pytest.fixture(autouse=True)
def auto_fast_backoff(fast_backoff_fixture: None) -> None:
    """Apply ``fast_backoff_fixture`` to every test in this module.

    Requesting the fixture is all that is needed, so there is no body.
    """
class FakeAsyncIterator:
    """Stand-in for ``aiomqtt.Client.messages`` that never delivers a message.

    While ``loop`` is true, ``__anext__`` keeps sleeping. Tests use this to
    exercise exceptions raised by other client functions. Once a test sets
    ``loop`` to false, ``__anext__`` returns ``None`` instead of waiting. It
    never raises ``StopAsyncIteration``.
    """

    def __init__(self) -> None:
        # Tests clear this flag to make __anext__ stop waiting.
        self.loop = True

    def __aiter__(self):
        return self

    async def __anext__(self) -> None:
        """Wait in 10ms steps until ``loop`` is cleared, then return ``None``."""
        while True:
            if not self.loop:
                return None
            await asyncio.sleep(0.01)
@pytest.fixture(name="message_iterator")
def message_iterator_fixture() -> FakeAsyncIterator:
    """Provide the fake ``messages`` iterator attached to the mocked client.

    By default the iterator waits forever without yielding a message. A test
    can set ``loop = False`` on it to make the iterator return immediately.
    """
    return FakeAsyncIterator()
@pytest.fixture(name="mock_client")
def mock_client_fixture(message_iterator: FakeAsyncIterator) -> AsyncMock:
    """Provide a mocked aiomqtt ``Client`` whose ``messages`` never yield.

    This is lighter weight than ``mock_aiomqtt_client``, which uses real sockets.
    The fixture returns the mock directly (it is not a generator), so the
    return annotation is the mock type itself.
    """
    mock_client = AsyncMock()
    mock_client.messages = message_iterator
    return mock_client
@pytest.fixture(name="create_client_side_effect")
def create_client_side_effect_fixture() -> Exception | None:
    """Provide the exception raised when entering the mocked MQTT client.

    The default is ``None``, meaning entering succeeds. Tests override the
    value with ``pytest.mark.parametrize`` to simulate connection errors.
    """
    no_error: Exception | None = None
    return no_error
@pytest.fixture(name="mock_aenter_client")
def mock_aenter_client_fixture(
    mock_client: AsyncMock,
    create_client_side_effect: Exception | None,
) -> AsyncMock:
    """Provide the mock used as the patched client's ``__aenter__``.

    Awaiting it yields ``mock_client``. When ``create_client_side_effect`` is
    set, that side effect is applied instead (an exception is raised). Tests
    may later swap ``side_effect`` for a callable.
    """
    return AsyncMock(return_value=mock_client, side_effect=create_client_side_effect)
@pytest.fixture(name="mqtt_client_lite")
def mqtt_client_lite_fixture(
    mock_client: AsyncMock,
    mock_aenter_client: AsyncMock,
) -> Generator[AsyncMock, None, None]:
    """Patch ``aiomqtt.Client`` in the session module with a lightweight mock.

    Every client the session constructs is entered via ``mock_aenter_client``
    and exited via a no-op ``AsyncMock``. The test receives ``mock_client``
    so it can configure and inspect calls on it.
    """
    client_cls = Mock()
    instance = client_cls.return_value
    instance.__aenter__ = mock_aenter_client
    instance.__aexit__ = AsyncMock()
    with patch("roborock.mqtt.roborock_session.aiomqtt.Client", client_cls):
        yield mock_client
async def test_session(push_mqtt_response: Callable[[bytes], None]) -> None:
    """Test subscribing, receiving and unsubscribing over an MQTT session."""
    push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
    session = await create_mqtt_session(FAKE_PARAMS)
    assert session.connected

    push_mqtt_response(mqtt_packet.gen_suback(mid=1))
    subscriber1 = Subscriber()
    unsub1 = await session.subscribe("topic-1", subscriber1.append)

    push_mqtt_response(mqtt_packet.gen_suback(mid=2))
    subscriber2 = Subscriber()
    await session.subscribe("topic-2", subscriber2.append)

    push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
    await subscriber1.wait()
    assert subscriber1.messages == [b"12345"]
    assert not subscriber2.messages

    push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
    await subscriber2.wait()
    assert subscriber2.messages == [b"67890"]

    push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC"))
    await subscriber1.wait()
    assert subscriber1.messages == [b"12345", b"ABC"]
    assert subscriber2.messages == [b"67890"]

    # Messages are no longer received after unsubscribing.
    unsub1()
    push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored"))
    # push_mqtt_response is synchronous. Without yielding to the event loop,
    # the session cannot dispatch the packet before the assertion below runs,
    # so the check could never fail. Give the dispatch a chance to happen.
    await asyncio.sleep(0.1)
    assert subscriber1.messages == [b"12345", b"ABC"]

    assert session.connected
    await session.close()
    assert not session.connected
async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None]) -> None:
    """Test that a session connects and closes cleanly without subscriptions."""
    connack = mqtt_packet.gen_connack(rc=0, flags=2)
    push_mqtt_response(connack)

    mqtt_session = await create_mqtt_session(FAKE_PARAMS)
    assert mqtt_session.connected

    await mqtt_session.close()
    assert not mqtt_session.connected
async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> None:
    """Test publishing a message on a connected MQTT session."""
    push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
    mqtt_session = await create_mqtt_session(FAKE_PARAMS)

    # NOTE(review): this queues an inbound PUBLISH on "topic-1" although
    # nothing subscribes to it. It does not look required for the outgoing
    # publish below; confirm before removing.
    incoming = mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")
    push_mqtt_response(incoming)

    await mqtt_session.publish("topic-1", message=b"payload")
    assert mqtt_session.connected

    await mqtt_session.close()
    assert not mqtt_session.connected
async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None:
    """Test that an MQTT error while publishing surfaces as MqttSessionException."""
    mqtt_session = await create_mqtt_session(FAKE_PARAMS)
    assert mqtt_session.connected

    # Make the underlying aiomqtt client fail on publish.
    mqtt_client_lite.publish.side_effect = aiomqtt.MqttError
    with pytest.raises(MqttSessionException, match="Error publishing message"):
        await mqtt_session.publish("topic-1", message=b"payload")

    await mqtt_session.close()
async def test_subscribe_failure(mqtt_client_lite: AsyncMock) -> None:
    """Test that an MQTT error while subscribing surfaces as MqttSessionException."""
    mqtt_session = await create_mqtt_session(FAKE_PARAMS)
    assert mqtt_session.connected

    # Make the underlying aiomqtt client fail on subscribe.
    mqtt_client_lite.subscribe.side_effect = aiomqtt.MqttError
    listener = Subscriber()
    with pytest.raises(MqttSessionException, match="Error subscribing to topic"):
        await mqtt_session.subscribe("topic-1", listener.append)

    # The failed subscription must not have delivered anything.
    assert not listener.messages
    await mqtt_session.close()
async def test_restart(push_mqtt_response: Callable[[bytes], None]) -> None:
    """Test restarting the MQTT session reconnects and re-subscribes.

    After ``restart()``, a new client is created, the existing topic is
    subscribed again, and messages keep flowing to the original callback.
    """
    push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
    session = await create_mqtt_session(FAKE_PARAMS)
    assert session.connected

    try:
        # Subscribe to a topic
        push_mqtt_response(mqtt_packet.gen_suback(mid=1))
        subscriber = Subscriber()
        await session.subscribe("topic-1", subscriber.append)

        # Verify we can receive messages
        push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345"))
        await subscriber.wait()
        assert subscriber.messages == [b"12345"]

        # Restart the session.
        await session.restart()

        # This is a hack where we grab on to the client and wait for it to be
        # closed properly and restarted. The wait is bounded so a regression
        # fails this test instead of hanging the whole run.
        async with asyncio.timeout(10):
            while session._client:  # type: ignore[attr-defined]
                await asyncio.sleep(0.01)

        # We need to queue up a new connack for the reconnection
        push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
        # And a suback for the resubscription. Since we created a new client,
        # the message ID resets to 1.
        push_mqtt_response(mqtt_packet.gen_suback(mid=1))

        push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890"))
        async with asyncio.timeout(10):
            await subscriber.wait()
        assert subscriber.messages == [b"12345", b"67890"]
    finally:
        # Always release the session, even when an assertion or timeout fails.
        await session.close()
async def test_idle_timeout_resubscribe(mqtt_client_lite: AsyncMock) -> None:
    """Test that resubscribing before idle timeout cancels the unsubscribe."""
    idle_timeout = datetime.timedelta(seconds=5)
    mqtt_session = RoborockMqttSession(FAKE_PARAMS, topic_idle_timeout=idle_timeout)
    await mqtt_session.start()
    assert mqtt_session.connected

    topic = "test/topic"

    # Subscribe, then drop the only callback so the idle timer starts.
    first = Subscriber()
    remove_first = await mqtt_session.subscribe(topic, first.append)
    remove_first()

    # Subscribing again well inside the 5s window should cancel the timer.
    second = Subscriber()
    await mqtt_session.subscribe(topic, second.append)

    # Let any pending async work run before checking.
    await asyncio.sleep(0.01)
    mqtt_client_lite.unsubscribe.assert_not_called()

    await mqtt_session.close()
async def test_idle_timeout_unsubscribe(mqtt_client_lite: AsyncMock) -> None:
    """Test that unsubscribe happens after idle timeout expires."""
    # A 50ms idle timeout keeps the test fast.
    mqtt_session = RoborockMqttSession(
        FAKE_PARAMS,
        topic_idle_timeout=datetime.timedelta(milliseconds=50),
    )
    await mqtt_session.start()
    assert mqtt_session.connected

    topic = "test/topic"
    listener = Subscriber()
    remove_listener = await mqtt_session.subscribe(topic, listener.append)

    # Dropping the last callback starts the idle timer; waiting about twice
    # the timeout lets it fire.
    remove_listener()
    await asyncio.sleep(0.1)
    mqtt_client_lite.unsubscribe.assert_called_once_with(topic)

    await mqtt_session.close()
async def test_idle_timeout_multiple_callbacks(mqtt_client_lite: AsyncMock) -> None:
    """Test that unsubscribe is delayed when multiple subscribers exist."""
    # A 50ms idle timeout keeps the test fast.
    mqtt_session = RoborockMqttSession(
        FAKE_PARAMS,
        topic_idle_timeout=datetime.timedelta(milliseconds=50),
    )
    await mqtt_session.start()
    assert mqtt_session.connected

    topic = "test/topic"
    first = Subscriber()
    second = Subscriber()
    remove_first = await mqtt_session.subscribe(topic, first.append)
    remove_second = await mqtt_session.subscribe(topic, second.append)

    # Removing one of two callbacks must not start the idle timer.
    remove_first()
    await asyncio.sleep(0.1)
    mqtt_client_lite.unsubscribe.assert_not_called()

    # Removing the last callback starts it; after the timeout the topic is dropped.
    remove_second()
    await asyncio.sleep(0.1)
    mqtt_client_lite.unsubscribe.assert_called_once_with(topic)

    await mqtt_session.close()
async def test_subscription_reuse(mqtt_client_lite: AsyncMock) -> None:
    """Test that subscriptions are reused and not duplicated."""
    mqtt_session = RoborockMqttSession(FAKE_PARAMS)
    await mqtt_session.start()
    assert mqtt_session.connected

    client_subscribe = mqtt_client_lite.subscribe
    client_unsubscribe = mqtt_client_lite.unsubscribe

    # The first callback for a topic subscribes on the client.
    remove_first = await mqtt_session.subscribe("topic1", Mock())
    client_subscribe.assert_called_with("topic1")
    client_subscribe.reset_mock()

    # A second callback on the same topic reuses that subscription.
    remove_second = await mqtt_session.subscribe("topic1", Mock())
    client_subscribe.assert_not_called()

    # Dropping one of two callbacks keeps the subscription.
    remove_first()
    client_unsubscribe.assert_not_called()

    # Dropping the last callback only starts the idle timer.
    remove_second()
    client_unsubscribe.assert_not_called()

    # Subscribing again while idle reuses the existing subscription.
    await mqtt_session.subscribe("topic1", Mock())
    client_subscribe.assert_not_called()

    await mqtt_session.close()
@pytest.mark.parametrize(
    ("side_effect", "expected_exception", "match"),
    [
        (
            aiomqtt.MqttError("Connection failed"),
            MqttSessionException,
            "Error starting MQTT session",
        ),
        (
            aiomqtt.MqttCodeError(rc=135),
            MqttSessionUnauthorized,
            "Authorization error starting MQTT session",
        ),
        (
            aiomqtt.MqttCodeError(rc=128),
            MqttSessionException,
            "Error starting MQTT session",
        ),
        (
            ValueError("Unexpected"),
            MqttSessionException,
            "Unexpected error starting session",
        ),
    ],
)
async def test_connect_failure(
    side_effect: Exception,
    expected_exception: type[Exception],
    match: str,
) -> None:
    """Test that errors entering the aiomqtt client map to session exceptions.

    Per the parameters above, rc=135 maps to ``MqttSessionUnauthorized``. Other
    MQTT errors and unexpected exceptions map to ``MqttSessionException``.
    """
    failing_aenter = AsyncMock(side_effect=side_effect)
    target = "roborock.mqtt.roborock_session.aiomqtt.Client.__aenter__"
    with patch(target, failing_aenter), pytest.raises(expected_exception, match=match):
        await create_mqtt_session(FAKE_PARAMS)
async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) -> None:
    """Test that session diagnostics track starts, subscriptions, dispatches and close."""
    diagnostics = Diagnostics()
    session_params = copy.deepcopy(FAKE_PARAMS)
    session_params.diagnostics = diagnostics

    push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
    mqtt_session = await create_mqtt_session(session_params)
    assert mqtt_session.connected

    # Right after connecting only the start counters are populated.
    snapshot = diagnostics.as_dict()
    assert snapshot.get("start_attempt") == 1
    assert snapshot.get("start_loop") == 1
    assert snapshot.get("start_success") == 1
    assert snapshot.get("subscribe_count") is None
    assert snapshot.get("dispatch_message_count") is None
    assert snapshot.get("close") is None

    push_mqtt_response(mqtt_packet.gen_suback(mid=1))
    first = Subscriber()
    remove_first = await mqtt_session.subscribe("topic-1", first.append)

    push_mqtt_response(mqtt_packet.gen_suback(mid=2))
    second = Subscriber()
    await mqtt_session.subscribe("topic-2", second.append)

    push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
    await first.wait()
    assert first.messages == [b"12345"]
    assert not second.messages

    push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
    await second.wait()
    assert second.messages == [b"67890"]

    push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC"))
    await first.wait()
    assert first.messages == [b"12345", b"ABC"]
    assert second.messages == [b"67890"]

    # Two subscriptions and three delivered messages while still open.
    snapshot = diagnostics.as_dict()
    assert snapshot.get("start_attempt") == 1
    assert snapshot.get("start_loop") == 1
    assert snapshot.get("subscribe_count") == 2
    assert snapshot.get("dispatch_message_count") == 3
    assert snapshot.get("close") is None

    # Messages are no longer received after unsubscribing.
    # NOTE(review): nothing is awaited between the push and the assertion, so
    # this check cannot observe a late dispatch of the "ignored" packet.
    remove_first()
    push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored"))
    assert first.messages == [b"12345", b"ABC"]
    assert mqtt_session.connected

    await mqtt_session.close()
    assert not mqtt_session.connected

    # After closing, "close" is recorded and the other checked counters are unchanged.
    snapshot = diagnostics.as_dict()
    assert snapshot.get("start_attempt") == 1
    assert snapshot.get("start_loop") == 1
    assert snapshot.get("subscribe_count") == 2
    assert snapshot.get("dispatch_message_count") == 3
    assert snapshot.get("close") == 1
@pytest.mark.parametrize(
    "create_client_side_effect",
    [
        # Unauthorized
        aiomqtt.MqttCodeError(rc=135),
    ],
)
async def test_session_unauthorized_hook(mqtt_client_lite: AsyncMock) -> None:
    """Test that the unauthorized hook fires when the initial connection is rejected."""
    hook_called = asyncio.Event()
    session_params = copy.deepcopy(FAKE_PARAMS)
    session_params.unauthorized_hook = hook_called.set

    with pytest.raises(MqttSessionUnauthorized):
        await create_mqtt_session(session_params)

    assert hook_called.is_set()
async def test_session_unauthorized_after_start(
    mock_aenter_client: AsyncMock,
    message_iterator: FakeAsyncIterator,
    mqtt_client_lite: AsyncMock,
    push_mqtt_response: Callable[[bytes], None],
) -> None:
    """Test that the unauthorized hook fires when a later reconnect is rejected."""
    hook_called = asyncio.Event()
    session_params = copy.deepcopy(FAKE_PARAMS)
    session_params.unauthorized_hook = hook_called.set

    # The first connection attempt succeeds; every later attempt fails with
    # an unauthorized (rc=135) error.
    attempts = 0

    def succeed_then_fail_unauthorized() -> Any:
        nonlocal attempts
        attempts += 1
        if attempts > 1:
            raise aiomqtt.MqttCodeError(rc=135)
        return mqtt_client_lite

    mock_aenter_client.side_effect = succeed_then_fail_unauthorized

    # Don't produce messages, just exit and restart to reconnect.
    message_iterator.loop = False

    mqtt_session = await create_mqtt_session(session_params)
    assert mqtt_session.connected

    # Keep an active subscription so reconnect attempts are not deferred.
    await mqtt_session.subscribe("topic-1", Subscriber().append)

    try:
        async with asyncio.timeout(10):
            assert await hook_called.wait()
    finally:
        await mqtt_session.close()
async def test_session_defers_reconnect_when_idle() -> None:
    """Test that reconnects are deferred when there are no active subscriptions.

    ``_run_connection`` is replaced with a fake that only counts calls and
    resolves the start future. With no subscriptions, the reconnect loop
    should connect once and then record a ``reconnect_deferred`` diagnostic
    instead of connecting again.
    """
    session = RoborockMqttSession(FAKE_PARAMS)
    # loop.create_future() is the documented way to create an asyncio Future.
    start_future: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    connect_attempts = 0

    async def fake_run_connection(start: asyncio.Future[None] | None) -> None:
        """Count the attempt and resolve ``start`` the first time it is seen."""
        nonlocal connect_attempts
        connect_attempts += 1
        if start and not start.done():
            start.set_result(None)

    with patch.object(session, "_run_connection", side_effect=fake_run_connection):
        reconnect_task = asyncio.create_task(session._run_reconnect_loop(start_future))
        try:
            # Bound the wait so the test fails instead of hanging if the loop
            # never attempts a connection.
            async with asyncio.timeout(10):
                await start_future
            # Give the loop time to (not) attempt a second connection.
            await asyncio.sleep(0.1)
            assert connect_attempts == 1
            assert session._diagnostics.as_dict().get("reconnect_deferred", 0) >= 1
        finally:
            session._stop = True
            reconnect_task.cancel()
            await asyncio.gather(reconnect_task, return_exceptions=True)