|
1 | 1 | import ctypes |
2 | 2 | import socket |
3 | 3 | from abc import ABCMeta, abstractmethod |
4 | | -from asyncio import get_running_loop |
5 | 4 | from contextlib import asynccontextmanager |
6 | 5 | from dataclasses import dataclass, field |
7 | 6 | from os import getenv, getuid |
|
14 | 13 | create_connected_udp_socket, |
15 | 14 | create_memory_object_stream, |
16 | 15 | ) |
17 | | -from anyio._backends._asyncio import SocketStream, StreamProtocol |
| 16 | +from anyio._core._eventloop import get_async_backend |
18 | 17 | from anyio.streams.stapled import StapledObjectStream |
19 | 18 |
|
20 | 19 | from .streams.websocket import WebsocketClientStream |
@@ -120,6 +119,7 @@ async def address(self): |
120 | 119 | else: |
121 | 120 | raise ValueError("enable_address mode is not true in the exporter configuration") |
122 | 121 |
|
| 122 | + |
123 | 123 | @dataclass(kw_only=True) |
124 | 124 | class UdpNetwork(NetworkInterface, Driver): |
125 | 125 | ''' |
@@ -207,12 +207,7 @@ async def connect(self): |
207 | 207 | if libc.connect(sock.fileno(), ctypes.byref(addr), ctypes.sizeof(addr)) < 0: |
208 | 208 | raise OSError(ctypes.get_errno(), "vsock connect() failed") |
209 | 209 |
|
210 | | - transport, protocol = await get_running_loop().create_connection( |
211 | | - StreamProtocol, |
212 | | - sock=sock, |
213 | | - ) |
214 | | - |
215 | | - yield SocketStream(transport, protocol) |
| 210 | + yield await get_async_backend().wrap_stream_socket(sock) |
216 | 211 |
|
217 | 212 |
|
218 | 213 | @dataclass(kw_only=True) |
@@ -301,26 +296,27 @@ class EchoNetwork(NetworkInterface, Driver): |
301 | 296 | @exportstream |
302 | 297 | @asynccontextmanager |
303 | 298 | async def connect(self): |
304 | | - tx, rx = create_memory_object_stream[bytes](32) # ty: ignore[call-non-callable] |
| 299 | + tx, rx = create_memory_object_stream[bytes](32) # ty: ignore[call-non-callable] |
305 | 300 | self.logger.debug("Connecting Echo") |
306 | 301 | async with StapledObjectStream(tx, rx) as stream: |
307 | 302 | yield stream |
308 | 303 |
|
309 | 304 |
|
310 | 305 | @dataclass(kw_only=True) |
311 | 306 | class WebsocketNetwork(NetworkInterface, Driver): |
312 | | - ''' |
| 307 | + """ |
313 | 308 | Handles websocket connections from a given url. |
314 | | - ''' |
| 309 | + """ |
| 310 | + |
315 | 311 | url: str |
316 | 312 | enable_address: bool = True |
317 | 313 |
|
318 | 314 | @exportstream |
319 | 315 | @asynccontextmanager |
320 | 316 | async def connect(self): |
321 | | - ''' |
| 317 | + """ |
322 | 318 | Create a websocket connection to `self.url` and streams its output. |
323 | | - ''' |
| 319 | + """ |
324 | 320 | self.logger.info("Connecting to %s", self.url) |
325 | 321 |
|
326 | 322 | async with websockets.connect(self.url) as websocket: |
|
0 commit comments