diff --git a/CHANGES/12234.bugfix.rst b/CHANGES/12234.bugfix.rst
new file mode 100644
index 00000000000..64bcfa24f69
--- /dev/null
+++ b/CHANGES/12234.bugfix.rst
@@ -0,0 +1,2 @@
+Fixed zstd decompression failing with ``ClientPayloadError`` when the server
+sends a response as multiple zstd frames -- by :user:`josu-moreno`.
diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt
index c3c16f82eee..4a3934e7df7 100644
--- a/CONTRIBUTORS.txt
+++ b/CONTRIBUTORS.txt
@@ -214,6 +214,7 @@ Jordan Borean
 Josep Cugat
 Josh Junon
 Joshu Coats
+Josu Moreno
 Julia Tsemusheva
 Julien Duponchelle
 Jungkook Park
diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py
index 0bc4a30d8ed..2a8818c4220 100644
--- a/aiohttp/compression_utils.py
+++ b/aiohttp/compression_utils.py
@@ -330,6 +330,7 @@ def __init__(
                 "Please install `backports.zstd` module"
             )
         self._obj = ZstdDecompressor()
+        self._pending_unused_data: bytes | None = None
         super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
 
     def decompress_sync(
@@ -342,7 +343,36 @@ def decompress_sync(
             if max_length == ZLIB_MAX_LENGTH_UNLIMITED
             else max_length
         )
-        return self._obj.decompress(data, zstd_max_length)
+        if self._pending_unused_data is not None:
+            data = self._pending_unused_data + data
+            self._pending_unused_data = None
+        result = self._obj.decompress(data, zstd_max_length)
+
+        # Handle multi-frame zstd streams.
+        # https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
+        # ZstdDecompressor handles one frame only. When a frame ends,
+        # eof becomes True and any trailing data goes to unused_data.
+        # We create a fresh decompressor to continue with the next frame.
+        while self._obj.eof and self._obj.unused_data:
+            unused_data = self._obj.unused_data
+            self._obj = ZstdDecompressor()
+            if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
+                # Remaining budget is the original cap minus everything
+                # produced so far; a running ``-= len(result)`` would
+                # over-subtract once result spans more than one frame.
+                zstd_max_length = max_length - len(result)
+                if zstd_max_length <= 0:
+                    self._pending_unused_data = unused_data
+                    break
+            result += self._obj.decompress(unused_data, zstd_max_length)
+
+        # Frame ended exactly at chunk boundary — no unused_data, but the
+        # next decompress_sync() call would fail on the spent decompressor.
+        # Prepare a fresh one for the next chunk.
+        if self._obj.eof:
+            self._obj = ZstdDecompressor()
+
+        return result
 
     def flush(self) -> bytes:
         return b""
diff --git a/tests/test_compression_utils.py b/tests/test_compression_utils.py
index fdaf91b36a0..3362b8feed0 100644
--- a/tests/test_compression_utils.py
+++ b/tests/test_compression_utils.py
@@ -1,8 +1,23 @@
 """Tests for compression utils."""
 
+import sys
+
 import pytest
 
-from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor
+from aiohttp.compression_utils import (
+    ZLibBackend,
+    ZLibCompressor,
+    ZLibDecompressor,
+    ZSTDDecompressor,
+)
+
+try:
+    if sys.version_info >= (3, 14):
+        import compression.zstd as zstandard  # noqa: I900
+    else:
+        import backports.zstd as zstandard
+except ImportError:  # pragma: no cover
+    zstandard = None  # type: ignore[assignment]
 
 
 @pytest.mark.usefixtures("parametrize_zlib_backend")
@@ -33,3 +48,50 @@ async def test_compression_round_trip_in_event_loop() -> None:
     compressed_data = await compressor.compress(data) + compressor.flush()
     decompressed_data = await decompressor.decompress(compressed_data)
     assert data == decompressed_data
+
+
+@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+def test_zstd_multi_frame_unlimited() -> None:
+    d = ZSTDDecompressor()
+    frame1 = zstandard.compress(b"AAAA")
+    frame2 = zstandard.compress(b"BBBB")
+    result = d.decompress_sync(frame1 + frame2)
+    assert result == b"AAAABBBB"
+
+
+@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+def test_zstd_multi_frame_max_length_partial() -> None:
+    d = ZSTDDecompressor()
+    frame1 = zstandard.compress(b"AAAA")
+    frame2 = zstandard.compress(b"BBBB")
+    result = d.decompress_sync(frame1 + frame2, max_length=6)
+    assert result == b"AAAABB"
+
+
+@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+def test_zstd_multi_frame_max_length_exhausted() -> None:
+    d = ZSTDDecompressor()
+    frame1 = zstandard.compress(b"AAAA")
+    frame2 = zstandard.compress(b"BBBB")
+    result = d.decompress_sync(frame1 + frame2, max_length=4)
+    assert result == b"AAAA"
+
+
+@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+def test_zstd_multi_frame_max_length_exhausted_preserves_unused_data() -> None:
+    d = ZSTDDecompressor()
+    frame1 = zstandard.compress(b"AAAA")
+    frame2 = zstandard.compress(b"BBBB")
+    frame3 = zstandard.compress(b"CCCC")
+    result1 = d.decompress_sync(frame1 + frame2, max_length=4)
+    assert result1 == b"AAAA"
+    result2 = d.decompress_sync(frame3)
+    assert result2 == b"BBBBCCCC"
+
+
+@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+def test_zstd_multi_frame_max_length_three_frames() -> None:
+    # Regression: the remaining budget must shrink only by the bytes each
+    # frame actually produced, so three frames fit within a cap of 12.
+    d = ZSTDDecompressor()
+    frames = b"".join(zstandard.compress(p) for p in (b"AAAA", b"BBBB", b"CCCC"))
+    result = d.decompress_sync(frames, max_length=12)
+    assert result == b"AAAABBBBCCCC"
diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py
index 1bf80d271c3..6e877aaacd2 100644
--- a/tests/test_http_parser.py
+++ b/tests/test_http_parser.py
@@ -2081,6 +2081,79 @@ async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
         assert b"zstd data" == out._buffer[0]
         assert out.is_eof()
 
+    @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+    async def test_http_payload_zstandard_multi_frame(
+        self, protocol: BaseProtocol
+    ) -> None:
+        frame1 = zstandard.compress(b"first")
+        frame2 = zstandard.compress(b"second")
+        payload = frame1 + frame2
+        out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        p = HttpPayloadParser(
+            out,
+            length=len(payload),
+            compression="zstd",
+            headers_parser=HeadersParser(),
+        )
+        p.feed_data(payload)
+        assert b"firstsecond" == b"".join(out._buffer)
+        assert out.is_eof()
+
+    @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+    async def test_http_payload_zstandard_multi_frame_chunked(
+        self, protocol: BaseProtocol
+    ) -> None:
+        frame1 = zstandard.compress(b"chunk1")
+        frame2 = zstandard.compress(b"chunk2")
+        out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        p = HttpPayloadParser(
+            out,
+            length=len(frame1) + len(frame2),
+            compression="zstd",
+            headers_parser=HeadersParser(),
+        )
+        p.feed_data(frame1)
+        p.feed_data(frame2)
+        assert b"chunk1chunk2" == b"".join(out._buffer)
+        assert out.is_eof()
+
+    @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+    async def test_http_payload_zstandard_frame_split_mid_chunk(
+        self, protocol: BaseProtocol
+    ) -> None:
+        frame1 = zstandard.compress(b"AAAA")
+        frame2 = zstandard.compress(b"BBBB")
+        combined = frame1 + frame2
+        split_point = len(frame1) + 3  # 3 bytes into frame2
+        out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        p = HttpPayloadParser(
+            out,
+            length=len(combined),
+            compression="zstd",
+            headers_parser=HeadersParser(),
+        )
+        p.feed_data(combined[:split_point])
+        p.feed_data(combined[split_point:])
+        assert b"AAAABBBB" == b"".join(out._buffer)
+        assert out.is_eof()
+
+    @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
+    async def test_http_payload_zstandard_many_small_frames(
+        self, protocol: BaseProtocol
+    ) -> None:
+        parts = [f"part{i}".encode() for i in range(10)]
+        payload = b"".join(zstandard.compress(p) for p in parts)
+        out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
+        p = HttpPayloadParser(
+            out,
+            length=len(payload),
+            compression="zstd",
+            headers_parser=HeadersParser(),
+        )
+        p.feed_data(payload)
+        assert b"".join(parts) == b"".join(out._buffer)
+        assert out.is_eof()
+
 
 class TestDeflateBuffer:
     async def test_feed_data(self, protocol: BaseProtocol) -> None: