-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathmqtt_manager.py
More file actions
113 lines (91 loc) · 4.45 KB
/
mqtt_manager.py
File metadata and controls
113 lines (91 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations
import asyncio
import dataclasses
import logging
from collections.abc import Coroutine
from typing import Callable, Self
from urllib.parse import urlparse
import aiomqtt
from aiomqtt import TLSParameters
from roborock import RoborockException, UserData
from roborock.protocol import MessageParser, md5hex
from .containers import DeviceData
# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)
@dataclasses.dataclass
class ClientWrapper:
    """Handles onto one live MQTT client, stored per account by the manager.

    Every field holds an async function (a coroutine *function*, not a
    coroutine object). Each function is a closure over the ``aiomqtt.Client``
    opened in ``RoborockMqttManager._new_connect``.
    """

    # ``await publish_function(device, payload)``: publish raw ``payload`` bytes
    # to the device's ``rr/m/i/...`` topic.
    publish_function: Callable[..., Coroutine]
    # ``await unsubscribe_function(device)``: unsubscribe from the device's
    # ``rr/m/o/...`` topic.
    unsubscribe_function: Callable[..., Coroutine]
    # ``await subscribe_function(device, callback)``: subscribe to the device's
    # ``rr/m/o/...`` topic and route parsed messages to ``callback``.
    subscribe_function: Callable[..., Coroutine]
class RoborockMqttManager:
    """Process-wide singleton that shares one MQTT connection per Roborock account.

    Connections are keyed by ``user_data.rriot.u``. Each connection runs in a
    background task (``_new_connect``). Once the broker connection is up, that
    task registers a ``ClientWrapper`` in ``client_wrappers``. It then
    dispatches incoming messages to the per-device callbacks registered through
    ``subscribe``, and removes the wrapper again when the connection ends.
    """

    # The class is a singleton, so these class-level dicts act as its state.
    # All keys are the account id ``rriot.u``.
    client_wrappers: dict[str, ClientWrapper] = {}
    # Strong references to each account's connection task. Without them an
    # un-awaited task could be garbage-collected while still running.
    _connect_tasks: dict[str, asyncio.Task] = {}
    # Set by the connection task once its wrapper is in ``client_wrappers``.
    _ready_events: dict[str, asyncio.Event] = {}
    # Upper bound (seconds) that publish/subscribe wait for a new connection.
    # This is an arbitrary bound that replaces the old fixed 2 s sleep.
    _connect_timeout: float = 30.0
    _instance: Self = None

    def __new__(cls) -> RoborockMqttManager:
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    async def connect(self, user_data: UserData):
        """Start a background connection for the account unless one exists.

        Returns without waiting for the broker connection to be established.
        Use ``_ensure_connected`` to wait for it.
        """
        key = user_data.rriot.u
        if key in self.client_wrappers:
            return
        in_flight = self._connect_tasks.get(key)
        if in_flight is not None and not in_flight.done():
            # An attempt for this account is already running; don't start a second one.
            return
        # There is no await between the checks above and registering the task,
        # so concurrent callers on the same event loop cannot both get here.
        ready = asyncio.Event()
        self._ready_events[key] = ready
        self._connect_tasks[key] = asyncio.get_running_loop().create_task(
            self._new_connect(user_data, ready)
        )

    async def _ensure_connected(self, user_data: UserData) -> ClientWrapper:
        """Connect if needed and wait until the account's ClientWrapper exists.

        Raises:
            RoborockException: if the connection task ends, or the wait times
                out, before the wrapper has been registered.
        """
        key = user_data.rriot.u
        wrapper = self.client_wrappers.get(key)
        if wrapper is not None:
            return wrapper
        await self.connect(user_data)
        wrapper = self.client_wrappers.get(key)
        if wrapper is not None:
            return wrapper

        task = self._connect_tasks[key]
        ready_waiter = asyncio.ensure_future(self._ready_events[key].wait())
        try:
            # Wake on whichever happens first: the connection becomes ready, or
            # the connection task finishes (which means it did not stay up).
            await asyncio.wait(
                {ready_waiter, task},
                timeout=self._connect_timeout,
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            ready_waiter.cancel()

        wrapper = self.client_wrappers.get(key)
        if wrapper is not None:
            return wrapper
        if task.done() and not task.cancelled():
            # Retrieving the exception also prevents the "never retrieved" warning.
            raise RoborockException(f"MQTT connection for {key} failed") from task.exception()
        raise RoborockException(f"MQTT connection for {key} was not established")

    async def _new_connect(self, user_data: UserData, ready: asyncio.Event | None = None):
        """Open the account's MQTT connection and run its message loop.

        Registers a ClientWrapper in ``client_wrappers`` once connected and sets
        ``ready`` (if given). It then dispatches incoming messages until the
        connection closes or the task is cancelled. On exit the wrapper is
        removed, so later calls reconnect instead of reusing a dead client.

        Raises:
            RoborockException: if ``rriot.r.m`` has no parsable hostname.
        """
        rriot = user_data.rriot
        mqtt_user = rriot.u
        hashed_user = md5hex(mqtt_user + ":" + rriot.k)[2:10]
        url = urlparse(rriot.r.m)
        if not isinstance(url.hostname, str):
            raise RoborockException("Url parsing returned an invalid hostname")
        mqtt_host = str(url.hostname)
        mqtt_port = url.port
        mqtt_password = rriot.s
        hashed_password = md5hex(mqtt_password + ":" + rriot.k)[16:]
        LOGGER.debug("Connecting to %s for %s", mqtt_host, mqtt_user)
        async with aiomqtt.Client(
            hostname=mqtt_host,
            port=mqtt_port,
            username=hashed_user,
            password=hashed_password,
            keepalive=60,
            tls_params=TLSParameters(),
        ) as client:
            LOGGER.info("Connected to %s for %s", mqtt_host, mqtt_user)
            callbacks: dict[str, Callable] = {}
            device_map: dict[str, DeviceData] = {}

            def topic(direction: str, device: DeviceData) -> str:
                # "i" is the topic this code publishes to; "o" is the one it
                # subscribes to for incoming device messages.
                return f"rr/m/{direction}/{mqtt_user}/{hashed_user}/{device.device.duid}"

            async def publish(device: DeviceData, payload: bytes):
                await client.publish(topic("i", device), payload=payload)

            async def subscribe(device: DeviceData, callback):
                duid = device.device.duid
                # Register before subscribing so messages that arrive while
                # the SUBSCRIBE is in flight are not dropped as unknown.
                callbacks[duid] = callback
                device_map[duid] = device
                out_topic = topic("o", device)
                LOGGER.debug("Subscribing to %s", out_topic)
                await client.subscribe(out_topic)
                LOGGER.debug("Subscribed to %s", out_topic)

            async def unsubscribe(device: DeviceData):
                await client.unsubscribe(topic("o", device))
                callbacks.pop(device.device.duid, None)
                device_map.pop(device.device.duid, None)

            wrapper = ClientWrapper(
                publish_function=publish, unsubscribe_function=unsubscribe, subscribe_function=subscribe
            )
            self.client_wrappers[mqtt_user] = wrapper
            if ready is not None:
                ready.set()
            try:
                async for message in client.messages:
                    device_id = message.topic.value.split("/")[-1]
                    device = device_map.get(device_id)
                    callback = callbacks.get(device_id)
                    if device is None or callback is None:
                        LOGGER.debug("Ignoring MQTT message for unknown device %s", device_id)
                        continue
                    try:
                        parsed = MessageParser.parse(message.payload, device.device.local_key)
                        callback(parsed)
                    except Exception:
                        # Keep serving other devices, but don't hide the failure.
                        LOGGER.exception("Failed to handle MQTT message for device %s", device_id)
            finally:
                # Only remove our own wrapper. A newer connection may already
                # have replaced it after a disconnect/reconnect.
                if self.client_wrappers.get(mqtt_user) is wrapper:
                    del self.client_wrappers[mqtt_user]

    async def disconnect(self, user_data: UserData):
        """Close the account's MQTT connection, if any.

        Cancelling the connection task exits the ``async with`` block that owns
        the client. Calling this for an account that is not connected is a no-op.
        """
        key = user_data.rriot.u
        self.client_wrappers.pop(key, None)
        self._ready_events.pop(key, None)
        task = self._connect_tasks.pop(key, None)
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            LOGGER.debug("MQTT connection task for %s ended with an error", key, exc_info=True)

    async def subscribe(self, user_data: UserData, device: DeviceData, callback):
        """Route messages for ``device`` to ``callback``, connecting first if needed."""
        wrapper = await self._ensure_connected(user_data)
        await wrapper.subscribe_function(device, callback)

    async def unsubscribe(self, user_data: UserData | None = None, device: DeviceData | None = None):
        """Stop routing messages for ``device`` on the account's connection.

        Calling with no arguments (the original signature) does nothing.
        Unsubscribing on an account with no active connection also does nothing.
        """
        if user_data is None or device is None:
            return
        wrapper = self.client_wrappers.get(user_data.rriot.u)
        if wrapper is None:
            return
        await wrapper.unsubscribe_function(device)

    async def publish(self, user_data: UserData, device, payload: bytes):
        """Publish ``payload`` to ``device``, connecting first if needed."""
        LOGGER.debug("Publishing topic for %s, Message: %s", device.device.duid, payload)
        wrapper = await self._ensure_connected(user_data)
        await wrapper.publish_function(device, payload)