-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathwebsocket_transport.py
More file actions
49 lines (41 loc) · 1.58 KB
/
websocket_transport.py
File metadata and controls
49 lines (41 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
import asyncio
import logging
from typing import AsyncGenerator, Any
import websockets
_logger = logging.getLogger(__name__)
class WebSocketTransport:
"""Transport that communicates over a WebSocket connection."""
def __init__(self, ws_url: str):
self.ws_url = ws_url
async def __call__(
self, send_queue: asyncio.Queue[dict]
) -> AsyncGenerator[dict[str, Any], None]:
_logger.info("Connecting to WebSocket at %s", self.ws_url)
async with websockets.connect(self.ws_url, open_timeout=60) as ws:
sender_task = asyncio.create_task(self._queue_sender(ws, send_queue))
try:
async for message in ws:
try:
data = json.loads(message)
yield data
except json.JSONDecodeError:
yield {"raw": message}
except websockets.exceptions.ConnectionClosed as e:
e.add_note(f"Send queue is empty: {send_queue.empty()}")
finally:
sender_task.cancel()
try:
await sender_task
except asyncio.CancelledError:
pass
_logger.info("Agent runtime connection closed")
async def _queue_sender(
self, ws: websockets.ClientConnection, send_queue: asyncio.Queue[dict]
) -> None:
while True:
message = await send_queue.get()
try:
await ws.send(json.dumps(message))
finally:
send_queue.task_done()