diff --git a/CHANGES/12822.feature.rst b/CHANGES/12822.feature.rst new file mode 100644 index 00000000000..9f2f116f71f --- /dev/null +++ b/CHANGES/12822.feature.rst @@ -0,0 +1,3 @@ +Added ``aiofastnet`` package to ``speedups`` extra. aiofastnet provides faster alternatives to the standard loop functions, which are used to run server or establish connections. If you experience any issues that you think might be related to this change, you can try to disable ``aiofastnet`` by uninstalling aiofastnet package. + +-- by :user:`tarasko`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 2aaa0a02403..837b8456f9d 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -367,6 +367,7 @@ Sunit Deshpande Sviatoslav Bulbakha Sviatoslav Sydorenko Taha Jahangir +Taras Kozlov Taras Voinarovskyi Terence Honles Thanos Lefteris diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 90cf6e11046..89f80c709da 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -53,6 +53,12 @@ from .log import client_logger from .resolver import DefaultResolver +try: + import aiofastnet +except ImportError: + aiofastnet = None # type: ignore[assignment] + + if sys.version_info >= (3, 12): from collections.abc import Buffer else: @@ -98,6 +104,82 @@ from .tracing import Trace +async def create_connection( + loop: asyncio.AbstractEventLoop, + protocol_factory: Callable[[], ResponseHandler], + *, + ssl: SSLContext | None, + sock: socket.socket, + server_hostname: str | None, + ssl_shutdown_timeout: float | None = None, +) -> tuple[asyncio.Transport, ResponseHandler]: + if aiofastnet is not None: + return await aiofastnet.create_connection( + loop, + protocol_factory, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ssl_shutdown_timeout=ssl_shutdown_timeout, + ) + else: + if sys.version_info >= (3, 11): # type: ignore[unreachable] + return await loop.create_connection( + protocol_factory, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ssl_shutdown_timeout=ssl_shutdown_timeout, + ) + else: + return await loop.create_connection( + protocol_factory, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ) + + +async def start_tls( + loop: asyncio.AbstractEventLoop, + transport: asyncio.Transport, + protocol: ResponseHandler, + sslcontext: SSLContext, + *, + server_hostname: str | None, + ssl_handshake_timeout: float | None, + ssl_shutdown_timeout: float | None = None, +) -> asyncio.BaseTransport | None: + if aiofastnet is not None: + return await aiofastnet.start_tls( + loop, + transport, + protocol, + sslcontext, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, + ) + else: + if sys.version_info >= (3, 11): # type: ignore[unreachable] + return await loop.start_tls( + transport, + protocol, + sslcontext, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_shutdown_timeout=ssl_shutdown_timeout, + ) + else: + return await loop.start_tls( + transport, + protocol, + sslcontext, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ) + + class Connection: """Represents a single connection.""" @@ -1266,7 +1348,7 @@ async def _wrap_create_connection( and sys.version_info >= (3, 11) ): kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout - return await self._loop.create_connection(*args, **kwargs, sock=sock) + return await create_connection(self._loop, *args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1297,10 +1379,14 @@ def _warn_about_tls_in_tls( return # Support in asyncio was added in Python 3.11 (bpo-44011) - asyncio_supports_tls_in_tls = sys.version_info >= (3, 11) or getattr( - underlying_transport, - "_start_tls_compatible", - False, + asyncio_supports_tls_in_tls = ( + sys.version_info >= (3, 11) + or getattr( + underlying_transport, + "_start_tls_compatible", + False, + ) + or aiofastnet is not None ) if asyncio_supports_tls_in_tls: @@ -1347,7 +1433,8 @@ async def _start_tls_connection( try: # ssl_shutdown_timeout is only available in Python 3.11+ if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout: - tls_transport = await self._loop.start_tls( + tls_transport = await start_tls( + self._loop, underlying_transport, tls_proto, sslcontext, @@ -1356,7 +1443,8 @@ async def _start_tls_connection( ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) else: - tls_transport = await self._loop.start_tls( + tls_transport = await start_tls( + self._loop, underlying_transport, tls_proto, sslcontext, diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index d45afc0dd7d..fab821caf94 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -9,7 +9,7 @@ from mimetypes import MimeTypes from stat import S_ISREG from types import MappingProxyType -from typing import IO, TYPE_CHECKING, Any, Final, Optional +from typing import TYPE_CHECKING, BinaryIO, Final, Optional from . import hdrs from .abc import AbstractStreamWriter @@ -30,10 +30,28 @@ if TYPE_CHECKING: from .web_request import BaseRequest +try: + import aiofastnet +except ImportError: + aiofastnet = None # type: ignore[assignment] + _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] +async def sendfile( + loop: asyncio.AbstractEventLoop, + transport: asyncio.Transport, + file: BinaryIO, + offset: int, + count: int, +) -> None: + if aiofastnet is not None: + await aiofastnet.sendfile(loop, transport, file, offset, count) + else: + await loop.sendfile(transport, file, offset, count) # type: ignore[unreachable] + + NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) CONTENT_TYPES: Final[MimeTypes] = MimeTypes() @@ -92,12 +110,12 @@ def __init__( self._path = pathlib.Path(path) self._chunk_size = chunk_size - def _seek_and_read(self, fobj: IO[Any], offset: int, chunk_size: int) -> bytes: + def _seek_and_read(self, fobj: BinaryIO, offset: int, chunk_size: int) -> bytes: fobj.seek(offset) - return fobj.read(chunk_size) # type: ignore[no-any-return] + return fobj.read(chunk_size) async def _sendfile_fallback( - self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int + self, writer: AbstractStreamWriter, fobj: BinaryIO, offset: int, count: int ) -> AbstractStreamWriter: # To keep memory usage low,fobj is transferred in chunks # controlled by the constructor's chunk_size argument. @@ -118,7 +136,7 @@ async def _sendfile_fallback( return writer async def _sendfile( - self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int + self, request: "BaseRequest", fobj: BinaryIO, offset: int, count: int ) -> AbstractStreamWriter: writer = await super().prepare(request) assert writer is not None @@ -132,7 +150,7 @@ async def _sendfile( raise ConnectionResetError("Connection lost") try: - await loop.sendfile(transport, fobj, offset, count) + await sendfile(loop, transport, fobj, offset, count) except NotImplementedError: return await self._sendfile_fallback(writer, fobj, offset, count) diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 82c3bd277f8..3c5feaf034d 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,6 +2,7 @@ import signal import socket from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any, Generic, TypeVar from yarl import URL @@ -21,6 +22,49 @@ except ImportError: # pragma: no cover SSLContext = object # type: ignore[misc,assignment] +try: + import aiofastnet +except ImportError: + aiofastnet = None # type: ignore[assignment] + + +async def create_server( + loop: asyncio.AbstractEventLoop, + protocol_factory: Callable[[], asyncio.Protocol], + host: str | None = None, + port: int | None = None, + *, + sock: socket.socket | None = None, + ssl: SSLContext | None = None, + backlog: int = 100, + reuse_address: bool | None = None, + reuse_port: bool | None = None, +) -> asyncio.Server: + if aiofastnet is not None: + return await aiofastnet.create_server( + loop, + protocol_factory, + host, + port, + sock=sock, + ssl=ssl, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + else: + return await loop.create_server( # type: ignore[unreachable] + protocol_factory, + host, + port, + sock=sock, + ssl=ssl, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + + __all__ = ( "BaseSite", "TCPSite", @@ -130,7 +174,8 @@ async def start(self) -> None: loop = asyncio.get_running_loop() server = self._runner.server assert server is not None - self._server = await loop.create_server( + self._server = await create_server( + loop, server, self._host, self._port, @@ -244,8 +289,8 @@ async def start(self) -> None: loop = asyncio.get_running_loop() server = self._runner.server assert server is not None - self._server = await loop.create_server( - server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog + self._server = await create_server( + loop, server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog ) diff --git a/docs/faq.rst b/docs/faq.rst index 3f50b855588..9b28b54b659 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -263,6 +263,67 @@ enable compression in NGINX (you are deploying aiohttp behind reverse proxy, right?). +How do I enable Kernel TLS, and should I do it? +----------------------------------------------- + +Kernel TLS (KTLS) allows aiohttp to move encryption and decryption of +TLS traffic from user space to the kernel. It was added to the Linux kernel in +4.13, but full support for TLS 1.3 and modern ciphers is available only +since 5.19. + +KTLS will be beneficial if you run an HTTPS server that often returns +:class:`~aiohttp.web.FileResponse` objects or you have a high-end NIC that can +offload TLS encryption. For ordinary +dynamic responses, small files, or deployments behind a TLS-terminating reverse +proxy, it is unlikely to help and may actually slightly degrade performance. + +KTLS is supported through the ``aiofastnet`` package, which is installed as +part of the ``speedups`` extra. + +To enable KTLS, you have to do and check the following: + +* Verify that ``aiofastnet`` is installed and can be imported. + + Currently, ``aiofastnet`` works only with CPython distributions that are + dynamically linked against OpenSSL. This is generally true for system Python + installations, Conda distributions, ``pyenv``, and + ``actions/setup-python`` in GitHub Actions, but not for Python installations + managed by ``uv``. + + .. code-block:: python + + try: + import aiofastnet + except ImportError: + aiofastnet = None + +* Make sure the Linux ``tls`` kernel module is loaded:: + + sudo modprobe tls + +* Make sure the ``ssl.OP_ENABLE_KTLS`` option is enabled in ``SSLContext`` + (available since Python 3.12):: + + sslcontext.options |= ssl.OP_ENABLE_KTLS + +* Make sure Python is using OpenSSL 3.0 or newer. OpenSSL should have been + built on a machine whose Linux headers are new enough. OpenSSL needs Linux + headers at least 4.13.0 to build the transmit path; older headers make it + skip KTLS support. Typically, Python is using the system OpenSSL on Linux, + but some times distributions ship their own OpenSSL. The following commands + will help identify the OpenSSL version and which ``libssl`` and ``libcrypto`` + are being used by the ``ssl`` module:: + + python -c "import ssl; print(ssl.OPENSSL_VERSION)" + ldd "$(python -c 'import _ssl; print(_ssl.__file__)')" + + +If ``ssl.OP_ENABLE_KTLS`` was requested in ``sslcontext``, but ``aiofastnet`` +could not enable KTLS, it will log a warning suggesting the possible reason. + +After enabling it, run your own benchmarks and verify that KTLS actually +speeds things up in your case. + How do I manage a ClientSession within a web server? ---------------------------------------------------- diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d803e9f526c..06972a16fdd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -4,6 +4,7 @@ ABI addons aiodns aioes +aiofastnet aiohttp aiohttpdemo aiohttp’s @@ -81,6 +82,7 @@ codspeed Codings committer committers +Conda config Config configs @@ -184,6 +186,7 @@ keepalive keepalived keepalives keepaliving +KTLS kib KiB kwarg @@ -228,6 +231,7 @@ namedtuple nameservers namespace netrc +NIC nginx Nginx Nikolay @@ -237,6 +241,7 @@ nowait OAuth Online optimizations +OpenSSL orjson os outcoming diff --git a/pyproject.toml b/pyproject.toml index 0c27cc88bb5..7bf480ac1a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dynamic = [ [project.optional-dependencies] speedups = [ "aiodns >= 3.3.0; sys_platform != 'android' and sys_platform != 'ios'", + "aiofastnet >= 0.11.0; platform_python_implementation == 'CPython' and (platform_machine == 'x86_64' or platform_machine == 'AMD64' or platform_machine == 'aarch64')", "Brotli >= 1.2; platform_python_implementation == 'CPython' and sys_platform != 'android' and sys_platform != 'ios'", "brotlicffi >= 1.2; platform_python_implementation != 'CPython'", "backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14' and sys_platform != 'android' and sys_platform != 'ios'", diff --git a/requirements/lint.in b/requirements/lint.in index c0a86f2435f..dc88ce01512 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -1,4 +1,5 @@ aiodns +aiofastnet >= 0.11.0 backports.zstd; implementation_name == "cpython" and python_version < "3.14" blockbuster freezegun diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index d70fc5a9dbc..ec66d1e8f1b 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,6 +1,7 @@ # Extracted from `pyproject.toml` via `make sync-direct-runtime-deps` aiodns >= 3.3.0; sys_platform != 'android' and sys_platform != 'ios' +aiofastnet >= 0.11.0; platform_python_implementation == 'CPython' and (platform_machine == 'x86_64' or platform_machine == 'AMD64' or platform_machine == 'aarch64') aiohappyeyeballs >= 2.5.0 aiosignal >= 1.4.0 async-timeout >= 4.0, < 6.0 ; python_version < '3.11' diff --git a/tests/conftest.py b/tests/conftest.py index 3869d93794e..3e1a6b2e9fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,6 +109,10 @@ def blockbuster(request: pytest.FixtureRequest) -> Iterator[None]: # synchronization in async code. # Allow lock.acquire calls to prevent these false positives bb.functions["threading.Lock.acquire"].deactivate() + + # aiofastnet is using sendfile on a non-blocking socket. + # blockbuster triggers anyway. Seems like a false positive + bb.functions["os.sendfile"].deactivate() yield diff --git a/tests/test_connector.py b/tests/test_connector.py index 259dfb495e5..65bc3a33230 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -672,7 +672,7 @@ async def test_tcp_connector_certificate_error( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -695,7 +695,7 @@ async def test_tcp_connector_server_hostname_default( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -713,7 +713,7 @@ async def test_tcp_connector_server_hostname_override( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -870,7 +870,7 @@ def get_extra_info(param: str) -> object: side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -971,7 +971,7 @@ async def create_connection( side_effect=sock_connect, ): with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1064,7 +1064,7 @@ async def create_connection( side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1146,7 +1146,7 @@ async def create_connection( side_effect=sock_connect, ): with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1259,7 +1259,7 @@ async def create_connection( side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -2249,7 +2249,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2267,7 +2267,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2286,7 +2286,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2314,7 +2314,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( assert any(issubclass(warn.category, RuntimeWarning) for warn in w) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2472,7 +2472,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2504,7 +2504,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=5.0) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2573,7 +2573,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_zero() -> None: mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -2605,7 +2607,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero() -> None: mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -2640,7 +2644,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero_pre_311() - mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -4102,7 +4108,7 @@ async def _resolve_host( first_conn = next(iter(conn._conns.values()))[0][0] assert first_conn.transport is not None - _sslcontext = first_conn.transport._ssl_protocol._sslcontext # type: ignore[attr-defined] + _sslcontext = first_conn.transport.get_extra_info("sslcontext") assert _sslcontext is client_ssl_ctx r.close() @@ -4589,7 +4595,7 @@ async def test_tcp_connector_socket_factory( ) with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 9cd8b3f1d6a..71e25aef48b 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -10,7 +10,7 @@ from yarl import URL import aiohttp -from aiohttp import hdrs +from aiohttp import connector as connector_module, hdrs from aiohttp.abc import AbstractStreamWriter from aiohttp.client_reqrep import ( ClientRequest, @@ -70,7 +70,7 @@ async def test_connect( # type: ignore[misc] } ) with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(proto.transport, proto), @@ -131,7 +131,7 @@ async def test_proxy_headers( # type: ignore[misc] } ) with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(proto.transport, proto), @@ -204,7 +204,7 @@ async def test_proxy_connection_error( # type: ignore[misc] } with mock.patch.object(connector, "_resolve_host", autospec=True, return_value=[r]): with mock.patch.object( - connector._loop, + connector_module, "create_connection", autospec=True, side_effect=OSError("dont take it serious"), @@ -274,13 +274,13 @@ async def test_proxy_server_hostname_default( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -360,13 +360,13 @@ async def test_proxy_server_hostname_override( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -482,14 +482,14 @@ def close(self) -> None: return_value=fingerprint_mock, ), mock.patch.object( # Called on connection to http://proxy.example.com - event_loop, + connector_module, "create_connection", autospec=True, spec_set=True, return_value=(mock.Mock(), mock.Mock()), ), mock.patch.object( # Called on connection to https://www.python.org - event_loop, + connector_module, "start_tls", autospec=True, spec_set=True, @@ -561,13 +561,13 @@ async def test_https_connect( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -647,14 +647,14 @@ async def test_https_connect_certificate_error( # type: ignore[misc] tr, proto = mock.Mock(), mock.Mock() # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): # Called on connection to https://www.python.org with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, side_effect=ssl.CertificateError, @@ -728,14 +728,14 @@ async def test_https_connect_ssl_error( # type: ignore[misc] tr, proto = mock.Mock(), mock.Mock() # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): # Called on connection to https://www.python.org with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, side_effect=ssl.SSLError, @@ -811,7 +811,7 @@ async def test_https_connect_http_proxy_error( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), @@ -891,7 +891,7 @@ async def test_https_connect_resp_start_error( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), @@ -940,7 +940,10 @@ async def test_request_port( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, "create_connection", autospec=True, return_value=(tr, proto) + connector_module, + "create_connection", + autospec=True, + return_value=(tr, proto), ): req = make_client_request( "GET", @@ -1008,13 +1011,13 @@ async def test_https_connect_pass_ssl_context( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -1031,6 +1034,7 @@ async def test_https_connect_pass_ssl_context( # type: ignore[misc] # ssl_shutdown_timeout=0 is not passed to start_tls tls_m.assert_called_with( + event_loop, mock.ANY, mock.ANY, _SSL_CONTEXT_VERIFIED, @@ -1103,13 +1107,13 @@ async def test_https_auth( # type: ignore[misc] ) as host_m: tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index eefe51db251..68b3da63220 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -7,7 +7,7 @@ from collections.abc import Awaitable, Callable, Iterator from contextlib import suppress from re import match as match_regex -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict from unittest import mock from uuid import uuid4 @@ -28,8 +28,6 @@ else: proxy = pytest.importorskip("proxy") -ASYNCIO_SUPPORTS_TLS_IN_TLS = sys.version_info >= (3, 11) - class _ResponseArgs(TypedDict): status: int @@ -49,7 +47,6 @@ async def get_request( ) -> ClientResponse: ... else: - from typing import Any async def get_request( method: str = "GET", @@ -66,6 +63,15 @@ async def get_request( return resp +try: + import aiofastnet +except ImportError: + aiofastnet = None # type: ignore[assignment] + + +ASYNCIO_SUPPORTS_TLS_IN_TLS = sys.version_info >= (3, 11) or aiofastnet is not None + + @pytest.fixture def secure_proxy_url(tls_certificate_pem_path: str) -> Iterator[URL]: """Return the URL of an instance of a running secure proxy. diff --git a/tests/test_run_app.py b/tests/test_run_app.py index a1cf5dd0f92..f62f15917c9 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -25,6 +25,7 @@ ServerDisconnectedError, WSCloseCode, web, + web_runner as web_runner_module, ) from aiohttp.log import access_logger from aiohttp.web_protocol import RequestHandler @@ -69,6 +70,17 @@ def skip_if_on_windows() -> None: pytest.skip("the test is not valid for Windows") +@pytest.fixture +def create_server_mock() -> Iterator[mock.AsyncMock]: + server = mock.create_autospec(asyncio.Server, spec_set=True, instance=True) + server.wait_closed.return_value = None + server.sockets = [] + create_server_mock = mock.AsyncMock(return_value=server) + + with mock.patch.object(web_runner_module, "create_server", create_server_mock): + yield create_server_mock + + @pytest.fixture def patched_loop( event_loop: asyncio.AbstractEventLoop, @@ -103,7 +115,9 @@ def f(*args: object) -> None: return f -def test_run_app_http(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_http( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() startup_handler = mock.AsyncMock() app.on_startup.append(startup_handler) @@ -112,19 +126,35 @@ def test_run_app_http(patched_loop: asyncio.AbstractEventLoop) -> None: web.run_app(app, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) startup_handler.assert_called_once_with(app) cleanup_handler.assert_called_once_with(app) -def test_run_app_close_loop(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_close_loop( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() web.run_app(app, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) assert patched_loop.is_closed() @@ -160,6 +190,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_single = [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -171,6 +202,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_multi = [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -180,6 +212,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=None, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8080, @@ -191,7 +224,14 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_default_8989 = [ mock.call( - mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=None + mock.ANY, + mock.ANY, + None, + 8989, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) ] mock_socket = mock.Mock(getsockname=lambda: ("mock-socket", 123)) @@ -203,6 +243,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -261,6 +302,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8000, @@ -270,6 +312,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=None, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8000, @@ -284,7 +327,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ( "Only socket", {"sock": [mock_socket]}, - [mock.call(mock.ANY, ssl=None, sock=mock_socket, backlog=128)], + [mock.call(mock.ANY, mock.ANY, ssl=None, sock=mock_socket, backlog=128)], [], ), ( @@ -292,6 +335,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"sock": [mock_socket], "port": 8765}, [ mock.call( + mock.ANY, mock.ANY, None, 8765, @@ -300,7 +344,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_address=None, reuse_port=None, ), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + mock.call(mock.ANY, mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), @@ -309,6 +353,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"sock": [mock_socket], "host": "localhost"}, [ mock.call( + mock.ANY, mock.ANY, "localhost", 8080, @@ -317,7 +362,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_address=None, reuse_port=None, ), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + mock.call(mock.ANY, mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), @@ -326,6 +371,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -342,6 +388,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_address": False}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -358,6 +405,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_address": True, "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -374,6 +422,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"port": 8989, "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8989, @@ -390,6 +439,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"host": ("127.0.0.1", "192.168.1.1"), "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -399,6 +449,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=True, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8080, @@ -419,6 +470,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, None, 8989, @@ -440,6 +492,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -466,15 +519,18 @@ def test_run_app_mixed_bindings( # type: ignore[misc] expected_server_calls: list[mock._Call], expected_unix_server_calls: list[mock._Call], patched_loop: asyncio.AbstractEventLoop, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop) assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls # type: ignore[attr-defined] - assert patched_loop.create_server.mock_calls == expected_server_calls # type: ignore[attr-defined] + assert create_server_mock.mock_calls == expected_server_calls -def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_https( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() ssl_context = ssl.create_default_context() @@ -482,7 +538,8 @@ def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop ) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8443, @@ -494,7 +551,9 @@ def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: def test_run_app_nondefault_host_port( - patched_loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket + patched_loop: asyncio.AbstractEventLoop, + unused_port_socket: socket.socket, + create_server_mock: mock.AsyncMock, ) -> None: port = unused_port_socket.getsockname()[1] host = "127.0.0.1" @@ -504,13 +563,22 @@ def test_run_app_nondefault_host_port( app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop ) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + host, + port, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) def test_run_app_with_sock( - patched_loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket + patched_loop: asyncio.AbstractEventLoop, + unused_port_socket: socket.socket, + create_server_mock: mock.AsyncMock, ) -> None: sock = unused_port_socket app = web.Application() @@ -521,12 +589,14 @@ def test_run_app_with_sock( loop=patched_loop, ) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, ssl=None, backlog=128 + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, ssl=None, backlog=128 ) -def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_multiple_hosts( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: hosts = ("127.0.0.1", "127.0.0.2") app = web.Application() @@ -534,6 +604,7 @@ def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None calls = map( lambda h: mock.call( + patched_loop, mock.ANY, h, 8080, @@ -544,15 +615,24 @@ def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None ), hosts, ) - patched_loop.create_server.assert_has_calls(calls) # type: ignore[attr-defined] + create_server_mock.assert_has_calls(list(calls)) -def test_run_app_custom_backlog(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_custom_backlog( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + None, + 8080, + ssl=None, + backlog=10, + reuse_address=None, + reuse_port=None, ) @@ -630,7 +710,9 @@ def test_run_app_abstract_linux_socket( def test_run_app_preexisting_inet_socket( - patched_loop: asyncio.AbstractEventLoop, mocker: MockerFixture + patched_loop: asyncio.AbstractEventLoop, + mocker: MockerFixture, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() @@ -642,15 +724,15 @@ def test_run_app_preexisting_inet_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://127.0.0.1:{port}" in printer.call_args[0][0] @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not available") def test_run_app_preexisting_inet6_socket( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: app = web.Application() @@ -662,15 +744,18 @@ def test_run_app_preexisting_inet6_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://[::1]:{port}" in printer.call_args[0][0] @skip_if_no_unix_socks def test_run_app_preexisting_unix_socket( - patched_loop: asyncio.AbstractEventLoop, unix_sockname: str, mocker: MockerFixture + patched_loop: asyncio.AbstractEventLoop, + unix_sockname: str, + mocker: MockerFixture, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() @@ -682,14 +767,14 @@ def test_run_app_preexisting_unix_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://unix:{unix_sockname}:" in printer.call_args[0][0] def test_run_app_multiple_preexisting_sockets( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: app = web.Application() @@ -704,10 +789,10 @@ def test_run_app_multiple_preexisting_sockets( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop) - patched_loop.create_server.assert_has_calls( # type: ignore[attr-defined] + create_server_mock.assert_has_calls( [ - mock.call(mock.ANY, sock=sock1, backlog=128, ssl=None), - mock.call(mock.ANY, sock=sock2, backlog=128, ssl=None), + mock.call(patched_loop, mock.ANY, sock=sock1, backlog=128, ssl=None), + mock.call(patched_loop, mock.ANY, sock=sock2, backlog=128, ssl=None), ] ) assert f"http://127.0.0.1:{port1}" in printer.call_args[0][0] @@ -753,9 +838,9 @@ def test_sigterm() -> None: def test_startup_cleanup_signals_even_on_failure( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: - patched_loop.create_server.side_effect = RuntimeError() # type: ignore[attr-defined] + create_server_mock.side_effect = RuntimeError() app = web.Application() startup_handler = mock.AsyncMock() @@ -770,7 +855,9 @@ def test_startup_cleanup_signals_even_on_failure( cleanup_handler.assert_called_once_with(app) -def test_run_app_coro(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_coro( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: startup_handler = cleanup_handler = None async def make_app() -> web.Application: @@ -784,8 +871,15 @@ async def make_app() -> web.Application: web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + None, + 8080, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) assert startup_handler is not None assert cleanup_handler is not None @@ -911,9 +1005,7 @@ async def on_startup(app: web.Application) -> None: assert task.cancelled() -def test_run_app_cancels_done_tasks( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_cancels_done_tasks(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() task = None @@ -932,9 +1024,7 @@ async def on_startup(app: web.Application) -> None: assert task.done() -def test_run_app_cancels_failed_tasks( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_cancels_failed_tasks(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() task = None @@ -1031,9 +1121,7 @@ async def init() -> web.Application: assert count == 3 -def test_run_app_raises_exception( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_raises_exception(patched_loop: asyncio.AbstractEventLoop) -> None: async def context(app: web.Application) -> AsyncIterator[None]: raise RuntimeError("foo") yield # type: ignore[unreachable] # pragma: no cover diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index c4b7b19e8b7..c191c605890 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -8,7 +8,7 @@ import pytest -from aiohttp import web +from aiohttp import web, web_runner as web_runner_module from aiohttp.abc import AbstractAccessLogger from aiohttp.test_utils import REUSE_ADDRESS from aiohttp.web_log import AccessLogger @@ -265,16 +265,16 @@ async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: site = web.TCPSite(runner) assert site.name == "http://0.0.0.0:8080" - m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) - m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) - with mock.patch( - "asyncio.get_running_loop", autospec=True, spec_set=True, return_value=m - ): + create_server = mock.AsyncMock( + return_value=mock.create_autospec(asyncio.Server, spec_set=True) + ) + + with mock.patch.object(web_runner_module, "create_server", create_server): await site.start() - m.create_server.assert_called_once() - args, kwargs = m.create_server.call_args - assert args == (runner.server, None, 8080) + create_server.assert_called_once() + args, kwargs = create_server.call_args + assert args == (asyncio.get_running_loop(), runner.server, None, 8080) async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None: diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index e4daf828fcd..1304eb92b0c 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -13,7 +13,7 @@ from pytest_aiohttp import AiohttpClient, AiohttpServer import aiohttp -from aiohttp import web +from aiohttp import web, web_fileresponse as web_fileresponse_module from aiohttp.compression_utils import ZLibBackend from aiohttp.typedefs import PathLike from aiohttp.web_fileresponse import NOSENDFILE @@ -74,14 +74,13 @@ async def sender(request: SubRequest) -> AsyncIterator[_Sender]: def maker(path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: ret = web.FileResponse(path, chunk_size=chunk_size) - rloop = asyncio.get_running_loop() - is_patched = rloop.sendfile is sendfile_mock + is_patched = web_fileresponse_module.sendfile is sendfile_mock assert is_patched if request.param == "no_sendfile" else not is_patched return ret if request.param == "no_sendfile": with mock.patch.object( - asyncio.get_running_loop(), + web_fileresponse_module, "sendfile", autospec=True, spec_set=True,