Skip to content

Commit 574a28e

Browse files
committed
feat: pylance basic type
fix: websockets.ConnectionClosed no await style: isort
1 parent 51a8bca commit 574a28e

11 files changed

Lines changed: 78 additions & 57 deletions

File tree

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
"tests"
44
],
55
"python.testing.unittestEnabled": false,
6-
"python.testing.pytestEnabled": true
6+
"python.testing.pytestEnabled": true,
7+
"python.analysis.typeCheckingMode": "basic"
78
}

dglabv3/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .dglab import dglabv3 # noqa: F401
2-
from .dtype import Channel, StrengthType, Button, Strength # noqa: F401
3-
from .waves import PULSES, ALL_PULSES, Pulse # noqa: F401
2+
from .dtype import Button, Channel, Strength, StrengthType # noqa: F401
3+
from .waves import ALL_PULSES, PULSES, Pulse # noqa: F401

dglabv3/dglab.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
import asyncio
2+
import io
13
import json
24
import logging
3-
import qrcode
4-
import io
5-
import asyncio
65
from threading import Event
7-
from websockets.asyncio.client import connect as ws_connect
6+
from typing import Optional
7+
8+
import qrcode
89
import websockets
10+
from websockets.asyncio.client import connect as ws_connect
911

10-
from dglabv3.dtype import Button, Channel, StrengthType, StrengthMode, MessageType, ChannelStrength, Strength
11-
from dglabv3.wsmessage import WSMessage, WStype
12+
from dglabv3.dtype import (Button, Channel, ChannelStrength, MessageType,
13+
Strength, StrengthMode, StrengthType)
1214
from dglabv3.event import EventEmitter
1315
from dglabv3.music_to_wave import convert_audio_to_v3_protocol
16+
from dglabv3.wsmessage import WSMessage, WStype
1417

1518
logging.basicConfig(level=logging.DEBUG)
1619
logger = logging.getLogger("dglabv3")
@@ -58,7 +61,7 @@ def is_connected(self) -> bool:
5861
"""
5962
是否連接到WebSocket
6063
"""
61-
return self.client and self.client.sock and self.client.sock.connected
64+
return self.client is not None
6265

6366
def is_linked_to_app(self) -> bool:
6467
"""
@@ -110,6 +113,9 @@ async def connect(self) -> None:
110113

111114
async def _listen(self):
112115
try:
116+
if self.client is None:
117+
logger.error("WebSocket client is None")
118+
return
113119
async for message in self.client:
114120
await self._handle_message(message)
115121
except websockets.ConnectionClosed:
@@ -118,7 +124,7 @@ async def _listen(self):
118124
logger.error(f"WebSocket error: {e}")
119125
raise ConnectionError("WebSocket error")
120126

121-
def generate_qrcode(self) -> io.BytesIO:
127+
def generate_qrcode(self) -> Optional[io.BytesIO]:
122128
"""
123129
生成QR code圖片
124130
"""
@@ -130,11 +136,11 @@ def generate_qrcode(self) -> io.BytesIO:
130136
qr.make(fit=True)
131137
img = qr.make_image(fill_color="black", back_color="white")
132138
saveimg = io.BytesIO()
133-
img.save(saveimg, format="PNG")
139+
img.save(saveimg, format="PNG") # type: ignore
134140
saveimg.seek(0)
135141
return saveimg
136142

137-
def generate_qrcode_text(self) -> str:
143+
def generate_qrcode_text(self) -> Optional[str]:
138144
"""
139145
生成QR code文字
140146
"""
@@ -174,7 +180,7 @@ async def _heartbeat(self):
174180

175181
except websockets.ConnectionClosed:
176182
logger.info("WebSocket connection closed")
177-
self.close()
183+
await self.close()
178184
except Exception as e:
179185
logger.error(f"Heartbeat error: {e}")
180186

@@ -184,7 +190,7 @@ def _start_heartbeat(self):
184190
self._heartbeat_task.cancel()
185191
self._heartbeat_task = asyncio.create_task(self._heartbeat())
186192

187-
async def _handle_message(self, data: str):
193+
async def _handle_message(self, data: websockets.Data):
188194
try:
189195
message = json.loads(data)
190196
WSmsg = WSMessage(message)
@@ -195,14 +201,17 @@ async def _handle_message(self, data: str):
195201
self._bind_event.set()
196202

197203
elif WSmsg.type == WStype.MSG:
198-
if WSmsg.msg.startswith("feedback"):
199-
button = WSmsg.feedback()
200-
await self._dispatch_button(button)
201-
elif WSmsg.msg.startswith("strength"):
202-
self.strength.set_strength(WSmsg.strength())
203-
await self._dispatch_strength(WSmsg.strength())
204+
if WSmsg.msg is not None:
205+
if WSmsg.msg.startswith("feedback"):
206+
button = WSmsg.feedback()
207+
await self._dispatch_button(button)
208+
elif WSmsg.msg.startswith("strength"):
209+
self.strength.set_strength(WSmsg.strength())
210+
await self._dispatch_strength(WSmsg.strength())
211+
else:
212+
logger.warning(f"Unknown message type: {WSmsg.msg}")
204213
else:
205-
logger.warning(f"Unknown message type: {WSmsg.msg}")
214+
logger.warning("Received message with None content")
206215

207216
logger.debug(f"Received message: {message}")
208217
except Exception as e:
@@ -263,18 +272,19 @@ async def send_wave_message(self, wave: list[list[list[int]]], time: int = 10, c
263272
time: 30\n
264273
channel: Channel.A
265274
"""
266-
if channel == 1:
267-
channel = "A"
268-
elif channel == 2:
269-
channel = "B"
270-
elif channel == 3:
271-
channel = "BOTH"
275+
channel_str = ""
276+
if channel == Channel.A:
277+
channel_str = "A"
278+
elif channel == Channel.B:
279+
channel_str = "B"
280+
elif channel == Channel.BOTH:
281+
channel_str = "BOTH"
272282

273-
def _create_wave_message(channel: str, wave, time: int) -> dict:
283+
def _create_wave_message(ch_str: str, wave, time: int) -> dict:
274284
return {
275285
"type": MessageType.CLIENT_MSG,
276-
"channel": channel,
277-
"message": f"{channel}:{json.dumps(self._wave2hex(wave))}",
286+
"channel": ch_str,
287+
"message": f"{ch_str}:{json.dumps(self._wave2hex(wave))}",
278288
"time": time,
279289
}
280290

@@ -283,12 +293,12 @@ def _create_wave_message(channel: str, wave, time: int) -> dict:
283293
# message2 : B通道波形数据(16进制HEX数组json,具体见上面的协议说明)
284294
# time1 : A通道波形数据持续发送时长
285295
# time2 : B通道波形数据持续发送时长
286-
if channel == "BOTH":
296+
if channel_str == "BOTH":
287297
for ch in ["A", "B"]:
288298
message = _create_wave_message(ch, wave, time)
289299
await self._send_message(message)
290300
else:
291-
message = _create_wave_message(channel, wave, time)
301+
message = _create_wave_message(channel_str, wave, time)
292302
await self._send_message(message)
293303

294304
async def clear_wave(self, channel: Channel):

dglabv3/dtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
from dataclasses import dataclass, field
12
from enum import Enum, IntEnum, StrEnum
23
from typing import Final
3-
from dataclasses import dataclass, field
44

55
__all__ = ["ChannelStrength", "StrengthType", "StrengthMode", "MessageType", "Channel"]
66

dglabv3/event.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
2-
from typing import Callable, Any, Dict, List
32
import functools
43
import logging
5-
4+
from typing import Any, Callable, Dict, List, Optional
65

76
logger = logging.getLogger("dglabv3.event")
87

98

10-
def event(name: str = None):
9+
def event(name: Optional[str] = None):
1110
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
1211
@functools.wraps(func)
1312
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:

dglabv3/music_to_wave.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import librosa
21
import math
3-
import numpy as np
2+
3+
import librosa
44
import matplotlib.pyplot as plt
5+
import numpy as np
56

67

78
def convert_audio_to_v3_protocol(mp3_file_path: str) -> list:

dglabv3/wsmessage.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from enum import Enum
22
from typing import Optional
3+
34
from dglabv3.dtype import Button, Strength
45

56

@@ -11,13 +12,10 @@ class WStype(Enum):
1112
ERROR = "error"
1213

1314

14-
15-
16-
1715
class WSMessage:
1816
def __init__(self, data: dict):
1917
self.type: WStype = WStype(data.get("type"))
20-
self.msg: Optional[dict] = data.get("message", None)
18+
self.msg: Optional[str] = data.get("message", None) # 將型別從 dict 改為 str
2119
self.targetID: Optional[str] = data.get("targetId", None)
2220
self.clientID: Optional[str] = data.get("clientId", None)
2321

@@ -34,8 +32,12 @@ def to_dict(self) -> dict:
3432
}
3533

3634
def feedback(self) -> Button:
35+
if self.msg is None:
36+
raise ValueError("Message is None, cannot get feedback")
3737
return Button(self.msg.split("-")[1])
3838

3939
def strength(self) -> Strength:
40+
if self.msg is None:
41+
raise ValueError("Message is None, cannot get strength")
4042
splitmsg = self.msg.split("-")[1].split("+")
4143
return Strength(A=int(splitmsg[0]), B=int(splitmsg[1]), MAXA=int(splitmsg[2]), MAXB=int(splitmsg[3]))

example/connect.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import asyncio
2-
from PIL import Image
3-
import sys
42
import os
3+
import sys
54

6-
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7-
from dglabv3 import dglabv3
8-
from dglabv3 import Channel, StrengthType, PULSES
5+
from PIL import Image
96

7+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8+
from dglabv3 import PULSES, Channel, StrengthType, dglabv3
109

1110
client = dglabv3()
1211

@@ -15,6 +14,9 @@ async def run():
1514
try:
1615
await client.connect_and_wait()
1716
qrcode = client.generate_qrcode()
17+
if qrcode is None:
18+
print("Failed to generate QR code.")
19+
return
1820
ig = Image.open(qrcode)
1921
ig.show()
2022
await client.wait_for_app_connect()

example/discordbot.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import random
2+
13
import discord
24
from discord import app_commands
3-
import random
45
from discord.ext import commands
5-
from dglabv3 import dglabv3, Channel, Pulse, ALL_PULSES, PULSES, Strength
6+
7+
from dglabv3 import ALL_PULSES, PULSES, Channel, Pulse, Strength, dglabv3
68

79

810
class dglab:

example/eventtest.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import asyncio
2-
from PIL import Image
3-
import sys
42
import os
3+
import sys
54

5+
from PIL import Image
66

77
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8-
from dglabv3 import dglabv3
9-
from dglabv3 import Channel, StrengthType, PULSES, Strength, Button
10-
8+
from dglabv3 import PULSES, Button, Channel, Strength, StrengthType, dglabv3
119

1210
client = dglabv3()
1311

@@ -26,6 +24,9 @@ async def run():
2624
try:
2725
await client.connect_and_wait()
2826
qrcode = client.generate_qrcode()
27+
if qrcode is None:
28+
print("Failed to generate QR code.")
29+
return
2930
ig = Image.open(qrcode)
3031
ig.show()
3132

0 commit comments

Comments
 (0)