Skip to content

Commit 368a133

Browse files
committed
vucm: explicitly mark tasks and commands
1 parent 8fbd048 commit 368a133

4 files changed

Lines changed: 147 additions & 17 deletions

File tree

enapter/vucm/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from .app import App, run
22
from .config import Config
3-
from .device import Device
3+
from .device import Device, device_command, device_task
44
from .ucm import UCM
55

66
__all__ = [
77
"App",
88
"Config",
99
"Device",
10+
"device_command",
11+
"device_task",
1012
"UCM",
1113
"run",
1214
]

enapter/vucm/device.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,57 @@
22
import concurrent
33
import functools
44
import traceback
5-
from typing import Optional, Set
5+
from typing import Any, Callable, Coroutine, Optional, Set
66

77
import enapter
88

99
from .logger import Logger
1010

11+
DEVICE_TASK_MARK = "_enapter_vucm_device_task"
12+
DEVICE_COMMAND_MARK = "_enapter_vucm_device_command"
13+
14+
DeviceTaskFunc = Callable[[Any], Coroutine]
15+
DeviceCommandFunc = Callable[..., Coroutine]
16+
17+
18+
def device_task(func: DeviceTaskFunc) -> DeviceTaskFunc:
19+
setattr(func, DEVICE_TASK_MARK, True)
20+
return func
21+
22+
23+
def device_command(func: DeviceCommandFunc) -> DeviceTaskFunc:
24+
setattr(func, DEVICE_COMMAND_MARK, True)
25+
return func
26+
27+
28+
def is_device_task(func: DeviceTaskFunc) -> bool:
29+
return getattr(func, DEVICE_TASK_MARK, False) is True
30+
31+
32+
def is_device_command(func: DeviceCommandFunc) -> bool:
33+
return getattr(func, DEVICE_COMMAND_MARK, False) is True
34+
1135

1236
class Device(enapter.async_.Routine):
1337
def __init__(
1438
self,
1539
channel,
1640
cmd_prefix="cmd_",
17-
task_prefix="task_",
1841
thread_pool_executor=None,
1942
) -> None:
2043
self.__channel = channel
2144

22-
self.__cmd_prefix = cmd_prefix
23-
self.__task_prefix = task_prefix
45+
self.__tasks = {}
46+
for name in dir(self):
47+
obj = getattr(self, name)
48+
if is_device_task(obj):
49+
self.__tasks[name] = obj
50+
51+
self.__commands = {}
52+
for name in dir(self):
53+
obj = getattr(self, name)
54+
if is_device_command(obj):
55+
self.__commands[name] = obj
2456

2557
if thread_pool_executor is None:
2658
thread_pool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -62,11 +94,8 @@ async def _run(self):
6294

6395
tasks = set()
6496

65-
for name in dir(self):
66-
if name.startswith(self.__task_prefix):
67-
task_func = getattr(self, name)
68-
name_without_prefix = name[len(self.__task_prefix) :]
69-
tasks.add(asyncio.create_task(task_func(), name=name_without_prefix))
97+
for name, func in self.__tasks.items():
98+
tasks.add(asyncio.create_task(func(), name=name))
7099

71100
tasks.add(
72101
asyncio.create_task(
@@ -107,8 +136,8 @@ async def __process_command_requests(self):
107136

108137
async def __execute_command(self, req):
109138
try:
110-
cmd = getattr(self, self.__cmd_prefix + req.name)
111-
except AttributeError:
139+
cmd = self._commands[req.name]
140+
except KeyError:
112141
return enapter.mqtt.api.CommandState.ERROR, {"reason": "unknown command"}
113142

114143
try:

enapter/vucm/ucm.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import enapter
44

5-
from .device import Device
5+
from .device import Device, device_command, device_task
66

77

88
class UCM(Device):
@@ -13,20 +13,24 @@ def __init__(self, mqtt_client, hardware_id):
1313
)
1414
)
1515

16-
async def cmd_reboot(self):
16+
@device_command
17+
async def reboot(self):
1718
await asyncio.sleep(0)
1819
raise NotImplementedError
1920

20-
async def cmd_upload_lua_script(self, url, sha1, payload=None):
21+
@device_command
22+
async def upload_lua_script(self, url, sha1, payload=None):
2123
await asyncio.sleep(0)
2224
raise NotImplementedError
2325

24-
async def task_telemetry_publisher(self):
26+
@device_task
27+
async def telemetry_publisher(self) -> None:
2528
while True:
2629
await self.send_telemetry()
2730
await asyncio.sleep(1)
2831

29-
async def task_properties_publisher(self):
32+
@device_task
33+
async def properties_publisher(self):
3034
while True:
3135
await self.send_properties({"virtual": True, "lua_api_ver": 1})
3236
await asyncio.sleep(10)

tests/unit/test_vucm.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,101 @@ async def test_run_in_thread(self, fake):
1111
assert await device.run_in_thread(lambda: 42) == 42
1212
assert device._Device__thread_pool_executor._shutdown
1313

14+
async def test_task_marks(self, fake):
15+
class MyDevice(enapter.vucm.Device):
16+
async def task_foo(self):
17+
pass
18+
19+
@enapter.vucm.device_task
20+
async def task_bar(self):
21+
pass
22+
23+
@enapter.vucm.device_task
24+
async def baz(self):
25+
pass
26+
27+
@enapter.vucm.device_task
28+
async def goo(self):
29+
pass
30+
31+
async with MyDevice(channel=MockChannel(fake)) as device:
32+
tasks = device._Device__tasks
33+
assert len(tasks) == 3
34+
assert "task_foo" not in tasks
35+
assert "task_bar" in tasks
36+
assert "baz" in tasks
37+
assert "goo" in tasks
38+
39+
async def test_command_marks(self, fake):
40+
class MyDevice(enapter.vucm.Device):
41+
async def cmd_foo(self):
42+
pass
43+
44+
async def cmd_foo2(self, a, b, c):
45+
pass
46+
47+
@enapter.vucm.device_command
48+
async def cmd_bar(self):
49+
pass
50+
51+
@enapter.vucm.device_command
52+
async def cmd_bar2(self, a, b, c):
53+
pass
54+
55+
@enapter.vucm.device_command
56+
async def baz(self):
57+
pass
58+
59+
@enapter.vucm.device_command
60+
async def baz2(self, a, b, c):
61+
pass
62+
63+
@enapter.vucm.device_command
64+
async def goo(self):
65+
pass
66+
67+
async with MyDevice(channel=MockChannel(fake)) as device:
68+
commands = device._Device__commands
69+
assert len(commands) == 5
70+
assert "cmd_foo" not in commands
71+
assert "cmd_foo2" not in commands
72+
assert "cmd_bar" in commands
73+
assert "cmd_bar2" in commands
74+
assert "baz" in commands
75+
assert "baz2" in commands
76+
assert "goo" in commands
77+
78+
async def test_task_and_commands_marks(self, fake):
79+
class MyDevice(enapter.vucm.Device):
80+
@enapter.vucm.device_task
81+
async def foo_task(self):
82+
pass
83+
84+
@enapter.vucm.device_task
85+
async def bar_task(self):
86+
pass
87+
88+
@enapter.vucm.device_command
89+
async def foo_command(self):
90+
pass
91+
92+
@enapter.vucm.device_command
93+
async def bar_command(self):
94+
pass
95+
96+
async with MyDevice(channel=MockChannel(fake)) as device:
97+
tasks = device._Device__tasks
98+
assert len(tasks) == 2
99+
assert "foo_task" in tasks
100+
assert "bar_task" in tasks
101+
assert "foo_command" not in tasks
102+
assert "bar_command" not in tasks
103+
commands = device._Device__commands
104+
assert "foo_task" not in commands
105+
assert "bar_task" not in commands
106+
assert "foo_command" in commands
107+
assert "bar_command" in commands
108+
14109

15110
class MockChannel:
16111
def __init__(self, fake):

0 commit comments

Comments
 (0)