1+ import asyncio
2+ import io
13import json
24import logging
3- import qrcode
4- import io
5- import asyncio
65from threading import Event
7- from websockets .asyncio .client import connect as ws_connect
6+ from typing import Optional
7+
8+ import qrcode
89import websockets
10+ from websockets .asyncio .client import connect as ws_connect
911
10- from dglabv3 .dtype import Button , Channel , StrengthType , StrengthMode , MessageType , ChannelStrength , Strength
11- from dglabv3 . wsmessage import WSMessage , WStype
12+ from dglabv3 .dtype import ( Button , Channel , ChannelStrength , MessageType ,
13+ Strength , StrengthMode , StrengthType )
1214from dglabv3 .event import EventEmitter
1315from dglabv3 .music_to_wave import convert_audio_to_v3_protocol
16+ from dglabv3 .wsmessage import WSMessage , WStype
1417
1518logging .basicConfig (level = logging .DEBUG )
1619logger = logging .getLogger ("dglabv3" )
@@ -58,7 +61,7 @@ def is_connected(self) -> bool:
5861 """
5962 是否連接到WebSocket
6063 """
61- return self .client and self . client . sock and self . client . sock . connected
64+ return self .client is not None
6265
6366 def is_linked_to_app (self ) -> bool :
6467 """
@@ -110,6 +113,9 @@ async def connect(self) -> None:
110113
111114 async def _listen (self ):
112115 try :
116+ if self .client is None :
117+ logger .error ("WebSocket client is None" )
118+ return
113119 async for message in self .client :
114120 await self ._handle_message (message )
115121 except websockets .ConnectionClosed :
@@ -118,7 +124,7 @@ async def _listen(self):
118124 logger .error (f"WebSocket error: { e } " )
119125 raise ConnectionError ("WebSocket error" )
120126
121- def generate_qrcode (self ) -> io .BytesIO :
127+ def generate_qrcode (self ) -> Optional [ io .BytesIO ] :
122128 """
123129 生成QR code圖片
124130 """
@@ -130,11 +136,11 @@ def generate_qrcode(self) -> io.BytesIO:
130136 qr .make (fit = True )
131137 img = qr .make_image (fill_color = "black" , back_color = "white" )
132138 saveimg = io .BytesIO ()
133- img .save (saveimg , format = "PNG" )
139+ img .save (saveimg , format = "PNG" ) # type: ignore
134140 saveimg .seek (0 )
135141 return saveimg
136142
137- def generate_qrcode_text (self ) -> str :
143+ def generate_qrcode_text (self ) -> Optional [ str ] :
138144 """
139145 生成QR code文字
140146 """
@@ -174,7 +180,7 @@ async def _heartbeat(self):
174180
175181 except websockets .ConnectionClosed :
176182 logger .info ("WebSocket connection closed" )
177- self .close ()
183+ await self .close ()
178184 except Exception as e :
179185 logger .error (f"Heartbeat error: { e } " )
180186
@@ -184,7 +190,7 @@ def _start_heartbeat(self):
184190 self ._heartbeat_task .cancel ()
185191 self ._heartbeat_task = asyncio .create_task (self ._heartbeat ())
186192
187- async def _handle_message (self , data : str ):
193+ async def _handle_message (self , data : websockets . Data ):
188194 try :
189195 message = json .loads (data )
190196 WSmsg = WSMessage (message )
@@ -195,14 +201,17 @@ async def _handle_message(self, data: str):
195201 self ._bind_event .set ()
196202
197203 elif WSmsg .type == WStype .MSG :
198- if WSmsg .msg .startswith ("feedback" ):
199- button = WSmsg .feedback ()
200- await self ._dispatch_button (button )
201- elif WSmsg .msg .startswith ("strength" ):
202- self .strength .set_strength (WSmsg .strength ())
203- await self ._dispatch_strength (WSmsg .strength ())
204+ if WSmsg .msg is not None :
205+ if WSmsg .msg .startswith ("feedback" ):
206+ button = WSmsg .feedback ()
207+ await self ._dispatch_button (button )
208+ elif WSmsg .msg .startswith ("strength" ):
209+ self .strength .set_strength (WSmsg .strength ())
210+ await self ._dispatch_strength (WSmsg .strength ())
211+ else :
212+ logger .warning (f"Unknown message type: { WSmsg .msg } " )
204213 else :
205- logger .warning (f"Unknown message type: { WSmsg . msg } " )
214+ logger .warning ("Received message with None content " )
206215
207216 logger .debug (f"Received message: { message } " )
208217 except Exception as e :
@@ -263,18 +272,19 @@ async def send_wave_message(self, wave: list[list[list[int]]], time: int = 10, c
263272 time: 30\n
264273 channel: Channel.A
265274 """
266- if channel == 1 :
267- channel = "A"
268- elif channel == 2 :
269- channel = "B"
270- elif channel == 3 :
271- channel = "BOTH"
275+ channel_str = ""
276+ if channel == Channel .A :
277+ channel_str = "A"
278+ elif channel == Channel .B :
279+ channel_str = "B"
280+ elif channel == Channel .BOTH :
281+ channel_str = "BOTH"
272282
273- def _create_wave_message (channel : str , wave , time : int ) -> dict :
283+ def _create_wave_message (ch_str : str , wave , time : int ) -> dict :
274284 return {
275285 "type" : MessageType .CLIENT_MSG ,
276- "channel" : channel ,
277- "message" : f"{ channel } :{ json .dumps (self ._wave2hex (wave ))} " ,
286+ "channel" : ch_str ,
287+ "message" : f"{ ch_str } :{ json .dumps (self ._wave2hex (wave ))} " ,
278288 "time" : time ,
279289 }
280290
@@ -283,12 +293,12 @@ def _create_wave_message(channel: str, wave, time: int) -> dict:
283293 # message2 : B通道波形数据(16进制HEX数组json,具体见上面的协议说明)
284294 # time1 : A通道波形数据持续发送时长
285295 # time2 : B通道波形数据持续发送时长
286- if channel == "BOTH" :
296+ if channel_str == "BOTH" :
287297 for ch in ["A" , "B" ]:
288298 message = _create_wave_message (ch , wave , time )
289299 await self ._send_message (message )
290300 else :
291- message = _create_wave_message (channel , wave , time )
301+ message = _create_wave_message (channel_str , wave , time )
292302 await self ._send_message (message )
293303
294304 async def clear_wave (self , channel : Channel ):
0 commit comments