Skip to content

Commit 6f49f61

Browse files
committed
feat: Add a CLI for exercising the asyncio MQTT session
1 parent 148a6fa commit 6f49f61

4 files changed

Lines changed: 80 additions & 7 deletions

File tree

roborock/cli.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from pathlib import Path
66
from typing import Any
7+
import asyncio
78

89
import click
910
from pyshark import FileCapture # type: ignore
@@ -12,8 +13,10 @@
1213

1314
from roborock import RoborockException
1415
from roborock.containers import DeviceData, HomeDataProduct, LoginData
15-
from roborock.protocol import MessageParser
16+
from roborock.protocol import MessageParser, create_mqtt_params
1617
from roborock.util import run_sync
18+
from roborock.mqtt.session import MqttParams, MqttSession
19+
from roborock.mqtt.roborock_session import create_mqtt_session
1720
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
1821
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
1922
from roborock.web_api import RoborockApiClient
@@ -45,7 +48,8 @@ def validate(self):
4548
if self._login_data is None:
4649
raise RoborockException("You must login first")
4750

48-
def login_data(self):
51+
def login_data(self) -> LoginData:
52+
"""Get the login data."""
4953
self.validate()
5054
return self._login_data
5155

@@ -79,6 +83,54 @@ async def login(ctx, email, password):
7983
context.update(LoginData(user_data=user_data, email=email))
8084

8185

86+
@click.command()
87+
@click.pass_context
88+
@click.option("--duration", default=10, help="Duration to run the MQTT session in seconds")
89+
@run_sync()
90+
async def session(ctx, duration: int):
91+
context: RoborockContext = ctx.obj
92+
login_data = context.login_data()
93+
94+
# Discovery devices if not already available
95+
if not login_data.home_data:
96+
await _discover(ctx)
97+
login_data = context.login_data()
98+
if not login_data.home_data or not login_data.home_data.devices:
99+
raise RoborockException("Unable to discover devices")
100+
101+
all_devices = login_data.home_data.devices + login_data.home_data.received_devices
102+
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")
103+
104+
rriot = login_data.user_data.rriot
105+
params = create_mqtt_params(rriot)
106+
107+
mqtt_session = await create_mqtt_session(params)
108+
click.echo("Starting MQTT session...")
109+
if not mqtt_session.connected:
110+
raise RoborockException("Failed to connect to MQTT broker")
111+
112+
def on_message(bytes: bytes):
113+
"""Callback function to handle incoming MQTT messages."""
114+
# Decode the first 20 bytes of the message for display
115+
bytes = bytes[:20]
116+
click.echo(f"Received message: b\"{bytes.decode('utf-8')}...\"")
117+
118+
unsubs = []
119+
for device in all_devices:
120+
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
121+
unsub = await mqtt_session.subscribe(device_topic, on_message)
122+
unsubs.append(unsub)
123+
124+
click.echo("MQTT session started. Listening for messages...")
125+
await asyncio.sleep(duration)
126+
127+
click.echo("Stopping MQTT session...")
128+
for unsub in unsubs:
129+
unsub()
130+
await mqtt_session.close()
131+
132+
133+
82134
async def _discover(ctx):
83135
context: RoborockContext = ctx.obj
84136
login_data = context.login_data()
@@ -253,6 +305,7 @@ def on_package(packet: Packet):
253305
cli.add_command(status)
254306
cli.add_command(command)
255307
cli.add_command(parser)
308+
cli.add_command(session)
256309

257310

258311
def main():

roborock/containers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from enum import Enum
1010
from typing import Any, NamedTuple, get_args, get_origin
1111

12+
from .mqtt.session import MqttParams
1213
from .code_mappings import (
1314
RoborockCategory,
1415
RoborockCleanType,

roborock/mqtt/roborock_session.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
116116
_LOGGER.info("MQTT error: %s", err)
117117
except asyncio.CancelledError as err:
118118
if start_future:
119-
_LOGGER.debug("MQTT loop was cancelled")
119+
_LOGGER.debug("MQTT loop was cancelled while starting")
120120
start_future.set_exception(err)
121-
_LOGGER.debug("MQTT loop was cancelled whiel starting")
121+
_LOGGER.debug("MQTT loop was cancelled")
122122
return
123123
# Catch exceptions to avoid crashing the loop
124124
# and to allow the loop to retry.
@@ -171,8 +171,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
171171
self._client = None
172172

173173
async def _process_message_loop(self, client: aiomqtt.Client) -> None:
174-
_LOGGER.debug("client=%s", client)
175-
_LOGGER.debug("Processing MQTT messages: %s", client.messages)
174+
_LOGGER.debug("Processing MQTT messages")
176175
async for message in client.messages:
177176
_LOGGER.debug("Received message: %s", message)
178177
for listener in self._listeners.get(message.topic.value, []):

roborock/protocol.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from asyncio import BaseTransport, Lock
1010
from collections.abc import Callable
11+
from urllib.parse import urlparse
1112

1213
from construct import ( # type: ignore
1314
Bytes,
@@ -30,8 +31,10 @@
3031
from Crypto.Cipher import AES
3132
from Crypto.Util.Padding import pad, unpad
3233

33-
from roborock import BroadcastMessage, RoborockException
34+
from roborock.containers import BroadcastMessage, RRiot
35+
from roborock.exceptions import RoborockException
3436
from roborock.roborock_message import RoborockMessage
37+
from roborock.mqtt.session import MqttParams
3538

3639
_LOGGER = logging.getLogger(__name__)
3740
SALT = b"TXdfu$jyZ#TZHsg4"
@@ -359,3 +362,20 @@ def build(
359362

360363
MessageParser: _Parser = _Parser(_Messages, True)
361364
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
365+
366+
367+
368+
def create_mqtt_params(rriot: RRiot) -> MqttParams:
369+
"""Return the MQTT parameters for this user."""
370+
url = urlparse(rriot.r.m)
371+
if not isinstance(url.hostname, str):
372+
raise RoborockException("Url parsing returned an invalid hostname")
373+
hashed_user = md5hex(rriot.u + ":" + rriot.k)[2:10]
374+
hashed_password = md5hex(rriot.s + ":" + rriot.k)[16:]
375+
return MqttParams(
376+
host=str(url.hostname),
377+
port=url.port,
378+
tls=(url.scheme == "ssl"),
379+
username=hashed_user,
380+
password=hashed_password,
381+
)

0 commit comments

Comments
 (0)