Skip to content

Commit 8a74257

Browse files
Restrict multipart header sizes (aio-libs#12208) (aio-libs#12228)
(cherry picked from commit 5fe9dfb)
1 parent 53b35a2 commit 8a74257

8 files changed

Lines changed: 109 additions & 20 deletions

File tree

aiohttp/multipart.py

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
)
4242
from .helpers import CHAR, TOKEN, parse_mimetype, reify
4343
from .http import HeadersParser
44+
from .http_exceptions import BadHttpMessage
4445
from .log import internal_logger
4546
from .payload import (
4647
JsonPayload,
@@ -658,7 +659,14 @@ class MultipartReader:
658659
#: Body part reader class for non multipart/* content types.
659660
part_reader_cls = BodyPartReader
660661

661-
def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
662+
def __init__(
663+
self,
664+
headers: Mapping[str, str],
665+
content: StreamReader,
666+
*,
667+
max_field_size: int = 8190,
668+
max_headers: int = 128,
669+
) -> None:
662670
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
663671
assert self._mimetype.type == "multipart", "multipart/* content type expected"
664672
if "boundary" not in self._mimetype.parameters:
@@ -669,8 +677,10 @@ def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
669677
self.headers = headers
670678
self._boundary = ("--" + self._get_boundary()).encode()
671679
self._content = content
672-
self._default_charset: Optional[str] = None
673-
self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
680+
self._default_charset: str | None = None
681+
self._last_part: MultipartReader | BodyPartReader | None = None
682+
self._max_field_size = max_field_size
683+
self._max_headers = max_headers
674684
self._at_eof = False
675685
self._at_bof = True
676686
self._unread: List[bytes] = []
@@ -770,7 +780,12 @@ def _get_part_reader(
770780
if mimetype.type == "multipart":
771781
if self.multipart_reader_cls is None:
772782
return type(self)(headers, self._content)
773-
return self.multipart_reader_cls(headers, self._content)
783+
return self.multipart_reader_cls(
784+
headers,
785+
self._content,
786+
max_field_size=self._max_field_size,
787+
max_headers=self._max_headers,
788+
)
774789
else:
775790
return self.part_reader_cls(
776791
self._boundary,
@@ -832,12 +847,14 @@ async def _read_boundary(self) -> None:
832847
async def _read_headers(self) -> "CIMultiDictProxy[str]":
833848
lines = []
834849
while True:
835-
chunk = await self._content.readline()
850+
chunk = await self._content.readline(max_line_length=self._max_field_size)
836851
chunk = chunk.rstrip(b"\r\n")
837852
lines.append(chunk)
838853
if not chunk:
839854
break
840-
parser = HeadersParser()
855+
if len(lines) > self._max_headers:
856+
raise BadHttpMessage("Too many headers received")
857+
parser = HeadersParser(max_field_size=self._max_field_size)
841858
headers, raw_headers = parser.parse_headers(lines)
842859
return headers
843860

aiohttp/streams.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
set_exception,
2222
set_result,
2323
)
24+
from .http_exceptions import LineTooLong
2425
from .log import internal_logger
2526

2627
__all__ = (
@@ -372,10 +373,12 @@ async def _wait(self, func_name: str) -> None:
372373
finally:
373374
self._waiter = None
374375

375-
async def readline(self) -> bytes:
376-
return await self.readuntil()
376+
async def readline(self, *, max_line_length: Optional[int] = None) -> bytes:
377+
return await self.readuntil(max_size=max_line_length)
377378

378-
async def readuntil(self, separator: bytes = b"\n") -> bytes:
379+
async def readuntil(
380+
self, separator: bytes = b"\n", *, max_size: Optional[int] = None
381+
) -> bytes:
379382
seplen = len(separator)
380383
if seplen == 0:
381384
raise ValueError("Separator should be at least one-byte string")
@@ -386,6 +389,7 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes:
386389
chunk = b""
387390
chunk_size = 0
388391
not_enough = True
392+
max_size = max_size or self._high_water
389393

390394
while not_enough:
391395
while self._buffer and not_enough:
@@ -400,8 +404,8 @@ async def readuntil(self, separator: bytes = b"\n") -> bytes:
400404
if ichar:
401405
not_enough = False
402406

403-
if chunk_size > self._high_water:
404-
raise ValueError("Chunk too big")
407+
if chunk_size > max_size:
408+
raise LineTooLong(chunk[:100] + b"...", max_size)
405409

406410
if self._eof:
407411
break
@@ -622,7 +626,7 @@ async def wait_eof(self) -> None:
622626
def feed_data(self, data: bytes, n: int = 0) -> None:
623627
pass
624628

625-
async def readline(self) -> bytes:
629+
async def readline(self, *, max_line_length: Optional[int] = None) -> bytes:
626630
return b""
627631

628632
async def read(self, n: int = -1) -> bytes:

aiohttp/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -729,6 +729,9 @@ def make_mocked_request(
729729

730730
if protocol is sentinel:
731731
protocol = mock.Mock()
732+
protocol.max_field_size = 8190
733+
protocol.max_line_length = 8190
734+
protocol.max_headers = 128
732735
protocol.transport = transport
733736
type(protocol).peername = mock.PropertyMock(
734737
return_value=transport.get_extra_info("peername")

aiohttp/web_protocol.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ class RequestHandler(BaseProtocol):
142142
"""
143143

144144
__slots__ = (
145+
"max_field_size",
146+
"max_headers",
147+
"max_line_size",
145148
"_request_count",
146149
"_keepalive",
147150
"_manager",
@@ -205,6 +208,10 @@ def __init__(
205208
self._request_handler: Optional[_RequestHandler] = manager.request_handler
206209
self._request_factory: Optional[_RequestFactory] = manager.request_factory
207210

211+
self.max_line_size = max_line_size
212+
self.max_headers = max_headers
213+
self.max_field_size = max_field_size
214+
208215
self._tcp_keepalive = tcp_keepalive
209216
# placeholder to be replaced on keepalive timeout setup
210217
self._next_keepalive_close_time = 0.0

aiohttp/web_request.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,12 @@ async def json(self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER) -> Any:
696696

697697
async def multipart(self) -> MultipartReader:
698698
"""Return async iterator to process BODY as multipart."""
699-
return MultipartReader(self._headers, self._payload)
699+
return MultipartReader(
700+
self._headers,
701+
self._payload,
702+
max_field_size=self._protocol.max_field_size,
703+
max_headers=self._protocol.max_headers,
704+
)
700705

701706
async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]":
702707
"""Return POST parameters."""

tests/test_multipart.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import pathlib
55
import sys
6+
from typing import Optional
67
from unittest import mock
78

89
import pytest
@@ -85,7 +86,7 @@ async def read(self, size=None):
8586
def at_eof(self):
8687
return self.content.tell() == len(self.content.getbuffer())
8788

88-
async def readline(self):
89+
async def readline(self, *, max_line_length: Optional[int] = None) -> bytes:
8990
return self.content.readline()
9091

9192
def unread_data(self, data):
@@ -856,7 +857,7 @@ async def read(self, size=None) -> bytes:
856857
def at_eof(self) -> bool:
857858
return not self.content
858859

859-
async def readline(self) -> bytes:
860+
async def readline(self, *, max_line_length: int | None = None) -> bytes:
860861
line = b""
861862
while self.content and b"\n" not in line:
862863
line += self.content.pop(0)

tests/test_streams.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from re_assert import Matches
1313

1414
from aiohttp import streams
15+
from aiohttp.http_exceptions import LineTooLong
1516

1617
DATA = b"line1\nline2\nline3\n"
1718

@@ -325,7 +326,7 @@ async def test_readline_limit_with_existing_data(self) -> None:
325326
stream.feed_data(b"li")
326327
stream.feed_data(b"ne1\nline2\n")
327328

328-
with pytest.raises(ValueError):
329+
with pytest.raises(LineTooLong):
329330
await stream.readline()
330331
# The buffer should contain the remaining data after exception
331332
stream.feed_eof()
@@ -346,7 +347,7 @@ def cb():
346347

347348
loop.call_soon(cb)
348349

349-
with pytest.raises(ValueError):
350+
with pytest.raises(LineTooLong):
350351
await stream.readline()
351352
data = await stream.read()
352353
assert b"chunk3\n" == data
@@ -436,7 +437,7 @@ async def test_readuntil_limit_with_existing_data(self, separator: bytes) -> Non
436437
stream.feed_data(b"li")
437438
stream.feed_data(b"ne1" + separator + b"line2" + separator)
438439

439-
with pytest.raises(ValueError):
440+
with pytest.raises(LineTooLong):
440441
await stream.readuntil(separator)
441442
# The buffer should contain the remaining data after exception
442443
stream.feed_eof()
@@ -458,7 +459,7 @@ def cb():
458459

459460
loop.call_soon(cb)
460461

461-
with pytest.raises(ValueError, match="Chunk too big"):
462+
with pytest.raises(LineTooLong):
462463
await stream.readuntil(separator)
463464
data = await stream.read()
464465
assert b"chunk3#" == data

tests/test_web_request.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from aiohttp import HttpVersion
1515
from aiohttp.base_protocol import BaseProtocol
16+
from aiohttp.http_exceptions import BadHttpMessage, LineTooLong
1617
from aiohttp.http_parser import RawRequestMessage
1718
from aiohttp.streams import StreamReader
1819
from aiohttp.test_utils import make_mocked_request
@@ -896,7 +897,57 @@ async def test_multipart_formdata_file(protocol: BaseProtocol) -> None:
896897
result["a_file"].file.close()
897898

898899

899-
async def test_make_too_big_request_limit_None(protocol) -> None:
900+
async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> None:
901+
many = b"".join(f"X-{i}: a\r\n".encode() for i in range(130))
902+
body = (
903+
b"--b\r\n"
904+
b'Content-Disposition: form-data; name="a"\r\n' + many + b"\r\n1\r\n"
905+
b"--b--\r\n"
906+
)
907+
content_type = "multipart/form-data; boundary=b"
908+
payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
909+
payload.feed_data(body)
910+
payload.feed_eof()
911+
req = make_mocked_request(
912+
"POST",
913+
"/",
914+
headers={"CONTENT-TYPE": content_type},
915+
payload=payload,
916+
)
917+
918+
with pytest.raises(BadHttpMessage, match="Too many headers received"):
919+
await req.post()
920+
921+
922+
async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> None:
923+
k = b"t" * 4100
924+
body = (
925+
b"--b\r\n"
926+
b'Content-Disposition: form-data; name="a"\r\n'
927+
+ k
928+
+ b":"
929+
+ k
930+
+ b"\r\n"
931+
+ b"\r\n1\r\n"
932+
b"--b--\r\n"
933+
)
934+
content_type = "multipart/form-data; boundary=b"
935+
payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
936+
payload.feed_data(body)
937+
payload.feed_eof()
938+
req = make_mocked_request(
939+
"POST",
940+
"/",
941+
headers={"CONTENT-TYPE": content_type},
942+
payload=payload,
943+
)
944+
945+
match = "400, message:\n Got more than 8190 bytes when reading"
946+
with pytest.raises(LineTooLong, match=match):
947+
await req.post()
948+
949+
950+
async def test_make_too_big_request_limit_None(protocol: BaseProtocol) -> None:
900951
payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop())
901952
large_file = 1024**2 * b"x"
902953
too_large_file = large_file + b"x"

0 commit comments

Comments
 (0)