Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/12234.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed zstd decompression failing with ``ClientPayloadError`` when the server
sends a response as multiple zstd frames -- by :user:`josu-moreno`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Jordan Borean
Josep Cugat
Josh Junon
Joshu Coats
Josu Moreno
Julia Tsemusheva
Julien Duponchelle
Jungkook Park
Expand Down
29 changes: 28 additions & 1 deletion aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def __init__(
"Please install `backports.zstd` module"
)
self._obj = ZstdDecompressor()
self._pending_unused_data: bytes | None = None
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

def decompress_sync(
Expand All @@ -342,7 +343,33 @@ def decompress_sync(
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
else max_length
)
return self._obj.decompress(data, zstd_max_length)
if self._pending_unused_data is not None:
data = self._pending_unused_data + data
self._pending_unused_data = None
result = self._obj.decompress(data, zstd_max_length)

# Handle multi-frame zstd streams.
# https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
# ZstdDecompressor handles one frame only. When a frame ends,
# eof becomes True and any trailing data goes to unused_data.
# We create a fresh decompressor to continue with the next frame.
while self._obj.eof and self._obj.unused_data:
unused_data = self._obj.unused_data
self._obj = ZstdDecompressor()
if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
zstd_max_length -= len(result)
if zstd_max_length <= 0:
self._pending_unused_data = unused_data
break
result += self._obj.decompress(unused_data, zstd_max_length)

# Frame ended exactly at chunk boundary — no unused_data, but the
# next feed_data() call would fail on the spent decompressor.
# Prepare a fresh one for the next chunk.
if self._obj.eof:
self._obj = ZstdDecompressor()

return result

def flush(self) -> bytes:
    """Return leftover decompressed output; the zstd decompressor holds none."""
    return bytes()
56 changes: 55 additions & 1 deletion tests/test_compression_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
"""Tests for compression utils."""

import sys

import pytest

from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor
from aiohttp.compression_utils import (
ZLibBackend,
ZLibCompressor,
ZLibDecompressor,
ZSTDDecompressor,
)

try:
if sys.version_info >= (3, 14):
import compression.zstd as zstandard # noqa: I900
else:
import backports.zstd as zstandard
except ImportError: # pragma: no cover
zstandard = None # type: ignore[assignment]


@pytest.mark.usefixtures("parametrize_zlib_backend")
Expand Down Expand Up @@ -33,3 +48,42 @@ async def test_compression_round_trip_in_event_loop() -> None:
compressed_data = await compressor.compress(data) + compressor.flush()
decompressed_data = await decompressor.decompress(compressed_data)
assert data == decompressed_data


@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
def test_zstd_multi_frame_unlimited() -> None:
    """Two concatenated zstd frames decompress fully when no limit is set."""
    decompressor = ZSTDDecompressor()
    stream = zstandard.compress(b"AAAA") + zstandard.compress(b"BBBB")
    assert decompressor.decompress_sync(stream) == b"AAAABBBB"


@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
def test_zstd_multi_frame_max_length_partial() -> None:
    """max_length may cut output inside the second frame of a multi-frame stream."""
    decompressor = ZSTDDecompressor()
    stream = zstandard.compress(b"AAAA") + zstandard.compress(b"BBBB")
    # 6 bytes: all of frame one plus the first two bytes of frame two.
    assert decompressor.decompress_sync(stream, max_length=6) == b"AAAABB"


@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
def test_zstd_multi_frame_max_length_exhausted() -> None:
    """Output stops at the frame boundary when max_length equals frame one's size."""
    decompressor = ZSTDDecompressor()
    stream = zstandard.compress(b"AAAA") + zstandard.compress(b"BBBB")
    assert decompressor.decompress_sync(stream, max_length=4) == b"AAAA"


@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
def test_zstd_multi_frame_max_length_exhausted_preserves_unused_data() -> None:
    """Data past an exhausted max_length is kept and decoded on the next call."""
    decompressor = ZSTDDecompressor()
    one, two, three = (zstandard.compress(p) for p in (b"AAAA", b"BBBB", b"CCCC"))
    # First call hits the limit after frame one; frame two must not be lost.
    assert decompressor.decompress_sync(one + two, max_length=4) == b"AAAA"
    # Second call drains the deferred frame two before consuming frame three.
    assert decompressor.decompress_sync(three) == b"BBBBCCCC"
73 changes: 73 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,6 +2081,79 @@ async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
assert b"zstd data" == out._buffer[0]
assert out.is_eof()

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_http_payload_zstandard_multi_frame(
    self, protocol: BaseProtocol
) -> None:
    """A zstd body made of two frames fed in one piece decodes completely."""
    body = zstandard.compress(b"first") + zstandard.compress(b"second")
    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    parser = HttpPayloadParser(
        stream,
        length=len(body),
        compression="zstd",
        headers_parser=HeadersParser(),
    )
    parser.feed_data(body)
    assert b"firstsecond" == b"".join(stream._buffer)
    assert stream.is_eof()

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_http_payload_zstandard_multi_frame_chunked(
    self, protocol: BaseProtocol
) -> None:
    """Two zstd frames delivered as separate feed_data() calls decode completely."""
    first = zstandard.compress(b"chunk1")
    second = zstandard.compress(b"chunk2")
    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    parser = HttpPayloadParser(
        stream,
        length=len(first) + len(second),
        compression="zstd",
        headers_parser=HeadersParser(),
    )
    for piece in (first, second):
        parser.feed_data(piece)
    assert b"chunk1chunk2" == b"".join(stream._buffer)
    assert stream.is_eof()

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_http_payload_zstandard_frame_split_mid_chunk(
    self, protocol: BaseProtocol
) -> None:
    """A feed boundary that lands inside the second zstd frame still decodes."""
    first = zstandard.compress(b"AAAA")
    body = first + zstandard.compress(b"BBBB")
    cut = len(first) + 3  # 3 bytes into frame2
    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    parser = HttpPayloadParser(
        stream,
        length=len(body),
        compression="zstd",
        headers_parser=HeadersParser(),
    )
    parser.feed_data(body[:cut])
    parser.feed_data(body[cut:])
    assert b"AAAABBBB" == b"".join(stream._buffer)
    assert stream.is_eof()

@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
async def test_http_payload_zstandard_many_small_frames(
    self, protocol: BaseProtocol
) -> None:
    """A body built from ten tiny zstd frames decodes into the concatenated parts."""
    pieces = [f"part{index}".encode() for index in range(10)]
    body = b"".join(zstandard.compress(piece) for piece in pieces)
    stream = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
    parser = HttpPayloadParser(
        stream,
        length=len(body),
        compression="zstd",
        headers_parser=HeadersParser(),
    )
    parser.feed_data(body)
    assert b"".join(pieces) == b"".join(stream._buffer)
    assert stream.is_eof()


class TestDeflateBuffer:
async def test_feed_data(self, protocol: BaseProtocol) -> None:
Expand Down
Loading