|
4 | 4 | import logging |
5 | 5 | from pathlib import Path |
6 | 6 | from typing import Any |
| 7 | +import asyncio |
7 | 8 |
|
8 | 9 | import click |
9 | 10 | from pyshark import FileCapture # type: ignore |
|
12 | 13 |
|
13 | 14 | from roborock import RoborockException |
14 | 15 | from roborock.containers import DeviceData, HomeDataProduct, LoginData |
15 | | -from roborock.protocol import MessageParser |
| 16 | +from roborock.protocol import MessageParser, create_mqtt_params |
16 | 17 | from roborock.util import run_sync |
| 18 | +from roborock.mqtt.session import MqttParams, MqttSession |
| 19 | +from roborock.mqtt.roborock_session import create_mqtt_session |
17 | 20 | from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 |
18 | 21 | from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 |
19 | 22 | from roborock.web_api import RoborockApiClient |
@@ -45,7 +48,8 @@ def validate(self): |
45 | 48 | if self._login_data is None: |
46 | 49 | raise RoborockException("You must login first") |
47 | 50 |
|
48 | | - def login_data(self): |
| 51 | + def login_data(self) -> LoginData: |
| 52 | + """Get the login data.""" |
49 | 53 | self.validate() |
50 | 54 | return self._login_data |
51 | 55 |
|
@@ -79,6 +83,54 @@ async def login(ctx, email, password): |
79 | 83 | context.update(LoginData(user_data=user_data, email=email)) |
80 | 84 |
|
81 | 85 |
|
| 86 | +@click.command() |
| 87 | +@click.pass_context |
| 88 | +@click.option("--duration", default=10, help="Duration to run the MQTT session in seconds") |
| 89 | +@run_sync() |
| 90 | +async def session(ctx, duration: int): |
| 91 | + context: RoborockContext = ctx.obj |
| 92 | + login_data = context.login_data() |
| 93 | + |
| 94 | + # Discovery devices if not already available |
| 95 | + if not login_data.home_data: |
| 96 | + await _discover(ctx) |
| 97 | + login_data = context.login_data() |
| 98 | + if not login_data.home_data or not login_data.home_data.devices: |
| 99 | + raise RoborockException("Unable to discover devices") |
| 100 | + |
| 101 | + all_devices = login_data.home_data.devices + login_data.home_data.received_devices |
| 102 | + click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}") |
| 103 | + |
| 104 | + rriot = login_data.user_data.rriot |
| 105 | + params = create_mqtt_params(rriot) |
| 106 | + |
| 107 | + mqtt_session = await create_mqtt_session(params) |
| 108 | + click.echo("Starting MQTT session...") |
| 109 | + if not mqtt_session.connected: |
| 110 | + raise RoborockException("Failed to connect to MQTT broker") |
| 111 | + |
| 112 | + def on_message(bytes: bytes): |
| 113 | + """Callback function to handle incoming MQTT messages.""" |
| 114 | + # Decode the first 20 bytes of the message for display |
| 115 | + bytes = bytes[:20] |
| 116 | + click.echo(f"Received message: b\"{bytes.decode('utf-8')}...\"") |
| 117 | + |
| 118 | + unsubs = [] |
| 119 | + for device in all_devices: |
| 120 | + device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}" |
| 121 | + unsub = await mqtt_session.subscribe(device_topic, on_message) |
| 122 | + unsubs.append(unsub) |
| 123 | + |
| 124 | + click.echo("MQTT session started. Listening for messages...") |
| 125 | + await asyncio.sleep(duration) |
| 126 | + |
| 127 | + click.echo("Stopping MQTT session...") |
| 128 | + for unsub in unsubs: |
| 129 | + unsub() |
| 130 | + await mqtt_session.close() |
| 131 | + |
| 132 | + |
| 133 | + |
82 | 134 | async def _discover(ctx): |
83 | 135 | context: RoborockContext = ctx.obj |
84 | 136 | login_data = context.login_data() |
@@ -253,6 +305,7 @@ def on_package(packet: Packet): |
253 | 305 | cli.add_command(status) |
254 | 306 | cli.add_command(command) |
255 | 307 | cli.add_command(parser) |
| 308 | +cli.add_command(session) |
256 | 309 |
|
257 | 310 |
|
258 | 311 | def main(): |
|
0 commit comments