
Commit bff9e72

fix: harden input validation, XML escaping, and streaming safety
- Add XML escaping to all user-controlled values in xml_responses.py
- Add gzip decompression size limit to prevent decompression bombs
- Harden chunked decoder: buffer limits, chunk size validation, error on truncation
- Cache KEK in Settings via PrivateAttr to avoid per-request SHA256
- Wrap range header parsing in try/except for malformed input
- Safe int() parsing for content-length headers across 3 locations
- Add usedforsecurity=False to all MD5 calls for FIPS compliance
- Remove dead code and duplicate imports in upload_part.py
- Fix over-indentation in upload_part SHA256 mismatch block
1 parent 52e1488 commit bff9e72

11 files changed

Lines changed: 127 additions & 69 deletions
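
The xml_responses.py diff is not included in the excerpt below, so for the first bullet above here is only an illustrative sketch of escaping user-controlled values with the standard library. The complete_multipart signature is taken from the call site in lifecycle.py further down; the body and the exact escaping calls are assumptions, not the commit's code:

from xml.sax.saxutils import escape


def complete_multipart(location: str, bucket: str, key: str, etag: str) -> str:
    # Hypothetical shape of an xml_responses helper; the real implementation may differ.
    # escape() neutralizes &, <, > in user-controlled values (bucket/key names)
    # before they are interpolated into the XML body.
    return (
        '<?xml version="1.0" encoding="UTF-8"?>'
        "<CompleteMultipartUploadResult>"
        f"<Location>{escape(location)}</Location>"
        f"<Bucket>{escape(bucket)}</Bucket>"
        f"<Key>{escape(key)}</Key>"
        f"<ETag>{escape(etag)}</ETag>"
        "</CompleteMultipartUploadResult>"
    )

For values placed in attribute positions, xml.sax.saxutils.quoteattr is the matching helper.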


s3proxy/config.py

Lines changed: 7 additions & 7 deletions
@@ -2,7 +2,7 @@
 
 import hashlib
 
-from pydantic import Field, field_validator
+from pydantic import Field, PrivateAttr
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
@@ -49,16 +49,16 @@ class Settings(BaseSettings):
     # Logging
     log_level: str = Field(default="INFO", description="Log level (DEBUG, INFO, WARNING, ERROR)")
 
-    @field_validator("encrypt_key")
-    @classmethod
-    def hash_encrypt_key(cls, v: str) -> str:
-        """Store the raw key - we'll hash it when needed."""
-        return v
+    # Cached KEK derived from encrypt_key (computed once in model_post_init)
+    _kek: bytes = PrivateAttr()
+
+    def model_post_init(self, __context: object) -> None:
+        self._kek = hashlib.sha256(self.encrypt_key.encode()).digest()
 
     @property
     def kek(self) -> bytes:
         """Get the 32-byte Key Encryption Key (SHA256 of encrypt_key)."""
-        return hashlib.sha256(self.encrypt_key.encode()).digest()
+        return self._kek
 
     @property
     def s3_endpoint(self) -> str:
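
A minimal standalone sketch of the PrivateAttr + model_post_init caching pattern used above; DemoSettings and its default value are hypothetical stand-ins, not the project's Settings class:

import hashlib

from pydantic import PrivateAttr
from pydantic_settings import BaseSettings


class DemoSettings(BaseSettings):
    encrypt_key: str = "example-key"  # hypothetical default, for the demo only

    # Private attrs are not validated fields, so they are assigned after init
    _kek: bytes = PrivateAttr()

    def model_post_init(self, __context: object) -> None:
        # Hash once at construction instead of on every property access
        self._kek = hashlib.sha256(self.encrypt_key.encode()).digest()

    @property
    def kek(self) -> bytes:
        return self._kek


settings = DemoSettings()
assert settings.kek == hashlib.sha256(settings.encrypt_key.encode()).digest()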

s3proxy/handlers/base.py

Lines changed: 16 additions & 10 deletions
@@ -98,15 +98,18 @@ def _parse_range(self, header: str, size: int) -> tuple[int, int]:
         if not header.startswith("bytes="):
             raise S3Error.invalid_range("Invalid range header format")
         spec = header[6:]
-        if spec.startswith("-"):
-            start = max(0, size - int(spec[1:]))
-            end = size - 1
-        elif spec.endswith("-"):
-            start = int(spec[:-1])
-            end = size - 1
-        else:
-            parts = spec.split("-")
-            start, end = int(parts[0]), min(int(parts[1]), size - 1)
+        try:
+            if spec.startswith("-"):
+                start = max(0, size - int(spec[1:]))
+                end = size - 1
+            elif spec.endswith("-"):
+                start = int(spec[:-1])
+                end = size - 1
+            else:
+                parts = spec.split("-")
+                start, end = int(parts[0]), min(int(parts[1]), size - 1)
+        except (ValueError, IndexError):
+            raise S3Error.invalid_range("Invalid range header format")
         if start > end or start >= size:
             raise S3Error.invalid_range("Range not satisfiable")
         return start, end
@@ -117,7 +120,10 @@ def _parse_copy_source_range(
         if not range_header:
             return 0, total_size - 1
         range_str = range_header.replace("bytes=", "")
-        start, end = map(int, range_str.split("-"))
+        try:
+            start, end = map(int, range_str.split("-"))
+        except (ValueError, TypeError):
+            raise S3Error.invalid_range("Invalid copy source range format")
         return start, end
 
     def _get_effective_etag(self, metadata: dict, fallback_etag: str) -> str:
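
A quick self-contained check of the failure modes the new try/except absorbs: each of these Range specs previously escaped _parse_range as a bare ValueError or IndexError and now surfaces as S3Error.invalid_range (the header list is illustrative, not exhaustive):

for header in ["bytes=abc-def", "bytes=5-x", "bytes=-", "bytes=12--", "bytes=5"]:
    spec = header[6:]
    try:
        if spec.startswith("-"):
            int(spec[1:])
        elif spec.endswith("-"):
            int(spec[:-1])
        else:
            parts = spec.split("-")
            int(parts[0]), int(parts[1])  # "bytes=5" has no dash -> IndexError
    except (ValueError, IndexError) as exc:
        print(f"{header!r} -> {type(exc).__name__}")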

s3proxy/handlers/multipart/copy.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ async def handle_upload_part_copy(self, request: Request, creds: S3Credentials)
         ciphertext = crypto.encrypt_part(plaintext, state.dek, upload_id, part_num)
         resp = await client.upload_part(bucket, key, upload_id, part_num, ciphertext)
 
-        body_md5 = hashlib.md5(plaintext).hexdigest()
+        body_md5 = hashlib.md5(plaintext, usedforsecurity=False).hexdigest()
         await self.multipart_manager.add_part(
             bucket, key, upload_id,
             PartMetadata(
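
Context for the usedforsecurity=False flag applied throughout this commit: on Python 3.9+ it marks a digest as non-cryptographic, so hashlib.md5() stays usable on FIPS-restricted OpenSSL builds that would otherwise reject it; here MD5 only fingerprints part bodies for ETag-style values. A minimal illustration:

import hashlib

# Non-cryptographic use: a content fingerprint for an ETag-style value.
# Without the flag, FIPS-mode builds may refuse to construct an MD5 object.
etag = hashlib.md5(b"example part body", usedforsecurity=False).hexdigest()
print(etag)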

s3proxy/handlers/multipart/lifecycle.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ async def handle_complete_multipart_upload(
         )
 
         location = f"{self.settings.s3_endpoint}/{bucket}/{key}"
-        etag = hashlib.md5(str(state.total_plaintext_size).encode()).hexdigest()
+        etag = hashlib.md5(str(state.total_plaintext_size).encode(), usedforsecurity=False).hexdigest()
 
         return Response(
             content=xml_responses.complete_multipart(location, bucket, key, etag),

s3proxy/handlers/multipart/upload_part.py

Lines changed: 15 additions & 17 deletions
@@ -7,7 +7,7 @@
 import time
 from collections import deque
 from collections.abc import AsyncIterator
-from typing import TYPE_CHECKING, NoReturn
+from typing import NoReturn
 
 import structlog
 from botocore.exceptions import ClientError
@@ -23,12 +23,9 @@
     PartMetadata,
     StateMissingError,
 )
-from ...streaming import decode_aws_chunked, decode_aws_chunked_stream
+from ...streaming import decode_aws_chunked_stream
 from ..base import BaseHandler
 
-if TYPE_CHECKING:
-    from collections.abc import AsyncIterator
-
 logger: BoundLogger = structlog.get_logger(__name__)
 
 # Limit concurrent internal part uploads to bound memory usage
@@ -49,7 +46,10 @@ async def handle_upload_part(self, request: Request, creds: S3Credentials) -> Re
         # Parse request info
         content_encoding = request.headers.get("content-encoding", "")
         content_sha = request.headers.get("x-amz-content-sha256", "")
-        content_length = int(request.headers.get("content-length", "0"))
+        try:
+            content_length = int(request.headers.get("content-length", "0"))
+        except ValueError:
+            content_length = 0
 
         upload_start_time = time.monotonic()
         logger.info(
@@ -94,14 +94,14 @@ async def handle_upload_part(self, request: Request, creds: S3Credentials) -> Re
 
         # Late signature verification for large signed uploads
         if is_large_signed and content_sha and result["computed_sha256"] != content_sha:
-                logger.warning(
-                    "UPLOAD_PART_SHA256_MISMATCH",
-                    bucket=bucket, key=key, part_num=part_num,
-                    expected=content_sha, computed=result["computed_sha256"],
-                )
-                raise S3Error.signature_does_not_match(
-                    "Signature verification failed"
-                )
+            logger.warning(
+                "UPLOAD_PART_SHA256_MISMATCH",
+                bucket=bucket, key=key, part_num=part_num,
+                expected=content_sha, computed=result["computed_sha256"],
+            )
+            raise S3Error.signature_does_not_match(
+                "Signature verification failed"
+            )
 
         upload_duration = time.monotonic() - upload_start_time
         logger.info(
@@ -156,7 +156,7 @@ async def _stream_and_upload(
         # Initialize state
         buffer_chunks: deque[bytes] = deque()
         buffer_size = 0
-        md5_hash = hashlib.md5()
+        md5_hash = hashlib.md5(usedforsecurity=False)
         sha256_hash = hashlib.sha256()
         total_plaintext_size = 0
         total_ciphertext_size = 0
@@ -282,8 +282,6 @@ async def _get_stream_source(
             content_length_mb=f"{content_length / 1024 / 1024:.2f}MB",
         )
         body = await request.body()
-        if needs_chunked_decode:
-            body = decode_aws_chunked(body)
 
         async def body_iter():
            yield body

s3proxy/handlers/objects/misc.py

Lines changed: 3 additions & 2 deletions
@@ -58,7 +58,8 @@ async def handle_head_object(self, request: Request, creds: S3Credentials) -> Re
             "Content-Length": str(meta.total_plaintext_size),
             "Content-Type": resp.get("ContentType", "application/octet-stream"),
             "ETag": f'"{hashlib.md5(
-                str(meta.total_plaintext_size).encode()
+                str(meta.total_plaintext_size).encode(),
+                usedforsecurity=False,
             ).hexdigest()}"',
             **extra_headers,
         }
@@ -261,7 +262,7 @@ async def _copy_encrypted(
 
         # Re-encrypt
         encrypted = crypto.encrypt_object(plaintext, self.settings.kek)
-        etag = hashlib.md5(plaintext).hexdigest()
+        etag = hashlib.md5(plaintext, usedforsecurity=False).hexdigest()
 
         # Build destination metadata
         dest_metadata = {

s3proxy/handlers/objects/put.py

Lines changed: 7 additions & 4 deletions
@@ -58,7 +58,10 @@ async def handle_put_object(self, request: Request, creds: S3Credentials) -> Res
         expires = request.headers.get("expires")
         tagging = request.headers.get("x-amz-tagging")
 
-        content_length = int(request.headers.get("content-length", "0"))
+        try:
+            content_length = int(request.headers.get("content-length", "0"))
+        except ValueError:
+            content_length = 0
         is_unsigned = content_sha == "UNSIGNED-PAYLOAD"
         is_streaming_sig = content_sha.startswith("STREAMING-")
         needs_chunked_decode = "aws-chunked" in content_encoding or is_streaming_sig
@@ -122,7 +125,7 @@ async def _put_buffered(
             plaintext_mb=round(len(body) / 1024 / 1024, 2),
             ciphertext_mb=round(len(encrypted.ciphertext) / 1024 / 1024, 2),
         )
-        etag = hashlib.md5(body).hexdigest()
+        etag = hashlib.md5(body, usedforsecurity=False).hexdigest()
 
         await client.put_object(
             bucket, key, encrypted.ciphertext,
@@ -164,7 +167,7 @@ async def _put_streaming(
         parts_complete: list[dict[str, Any]] = []
         total_plaintext_size = 0
         part_num = 0
-        md5_hash = hashlib.md5()
+        md5_hash = hashlib.md5(usedforsecurity=False)
         sha256_hash = hashlib.sha256() if expected_sha256 else None
         buffer = bytearray()
 
@@ -173,7 +176,7 @@ async def upload_part(data: bytes) -> None:
             part_num += 1
             nonce = crypto.derive_part_nonce(upload_id, part_num)
             data_len = len(data)
-            data_md5 = hashlib.md5(data).hexdigest()
+            data_md5 = hashlib.md5(data, usedforsecurity=False).hexdigest()
             ciphertext = crypto.encrypt(data, dek, nonce)
             cipher_len = len(ciphertext)
             del data  # Free memory

s3proxy/request_handler.py

Lines changed: 4 additions & 1 deletion
@@ -91,7 +91,10 @@ async def handle_proxy_request(
         memory_limit = concurrency.get_memory_limit()
 
         if memory_limit > 0 and needs_limit:
-            content_length = int(request.headers.get("content-length", "0"))
+            try:
+                content_length = int(request.headers.get("content-length", "0"))
+            except ValueError:
+                content_length = 0
             memory_needed = concurrency.estimate_memory_footprint(method, content_length)
 
             logger.info(
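
This is the third call site, after put.py and upload_part.py above, to get the guarded content-length parse. Expressed as a hypothetical standalone helper for illustration only; the commit inlines the try/except at each site rather than factoring it out:

def parse_content_length(raw: str | None) -> int:
    """Parse a Content-Length header value, treating missing or malformed input as 0."""
    try:
        return int(raw or "0")
    except ValueError:
        return 0


assert parse_content_length("1048576") == 1048576
assert parse_content_length(None) == 0
assert parse_content_length("not-a-number") == 0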

s3proxy/state/metadata.py

Lines changed: 17 additions & 1 deletion
@@ -66,10 +66,26 @@ def encode_multipart_metadata(meta: MultipartMetadata) -> str:
     return base64.b64encode(compressed).decode()
 
 
+
+# Maximum decompressed metadata size (10 MB) — prevents gzip bombs
+MAX_METADATA_SIZE = 10 * 1024 * 1024
+
+
+def _safe_gzip_decompress(data: bytes, max_size: int = MAX_METADATA_SIZE) -> bytes:
+    """Decompress gzip data with a size limit to prevent decompression bombs."""
+    with gzip.GzipFile(fileobj=__import__("io").BytesIO(data)) as f:
+        result = f.read(max_size + 1)
+    if len(result) > max_size:
+        raise ValueError(
+            f"Decompressed metadata exceeds {max_size} bytes limit"
+        )
+    return result
+
+
 def decode_multipart_metadata(encoded: str) -> MultipartMetadata:
     """Decode metadata from base64-compressed JSON."""
     compressed = base64.b64decode(encoded)
-    json_bytes = gzip.decompress(compressed)
+    json_bytes = _safe_gzip_decompress(compressed)
     data = json_loads(json_bytes)
 
     return MultipartMetadata(
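
To see the cap in action, a small self-contained sketch that mirrors the helper above (with io imported normally rather than via __import__) and feeds it a highly compressible payload:

import gzip
import io

MAX_METADATA_SIZE = 10 * 1024 * 1024  # same 10 MB cap as above


def safe_gzip_decompress(data: bytes, max_size: int = MAX_METADATA_SIZE) -> bytes:
    with gzip.GzipFile(fileobj=io.BytesIO(data)) as f:
        # Read at most max_size + 1 bytes; anything past the cap means "too big"
        result = f.read(max_size + 1)
    if len(result) > max_size:
        raise ValueError(f"Decompressed metadata exceeds {max_size} bytes limit")
    return result


# 20 MB of zeros gzips to a few tens of KB, but decompression stops at the cap
bomb = gzip.compress(b"\x00" * (20 * 1024 * 1024))
try:
    safe_gzip_decompress(bomb)
except ValueError as exc:
    print(exc)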

s3proxy/streaming/chunked.py

Lines changed: 43 additions & 12 deletions
@@ -13,6 +13,26 @@
 # Streaming chunk size for reads/writes
 STREAM_CHUNK_SIZE = 64 * 1024  # 64KB chunks for streaming
 
+# Safety limits for chunked decoding
+_MAX_CHUNK_HEADER_SIZE = 4096  # Max header line (hex size + signature)
+_MAX_CHUNK_SIZE = 64 * 1024 * 1024  # 64 MB max per chunk
+_MAX_BUFFER_SIZE = 66 * 1024 * 1024  # Slightly above max chunk to hold chunk + framing
+
+
+def _parse_chunk_size(header: bytes) -> int:
+    """Parse and validate chunk size from header bytes."""
+    size_str = header.split(b";")[0].strip()
+    if not size_str:
+        raise ValueError("Empty chunk size")
+    chunk_size = int(size_str, 16)
+    if chunk_size < 0:
+        raise ValueError(f"Negative chunk size: {chunk_size}")
+    if chunk_size > _MAX_CHUNK_SIZE:
+        raise ValueError(
+            f"Chunk size {chunk_size} exceeds maximum {_MAX_CHUNK_SIZE}"
+        )
+    return chunk_size
+
 
 def decode_aws_chunked(body: bytes) -> bytes:
     """Decode aws-chunked transfer encoding from buffered body.
@@ -22,25 +42,27 @@ def decode_aws_chunked(body: bytes) -> bytes:
 
     Returns:
         Decoded bytes without chunk headers
+
+    Raises:
+        ValueError: If chunked encoding is malformed or truncated.
     """
     result = bytearray()
     pos = 0
     while pos < len(body):
         header_end = body.find(b"\r\n", pos)
         if header_end == -1:
-            break
+            raise ValueError("Truncated chunk: missing header terminator")
         header = body[pos:header_end]
-        size_str = header.split(b";")[0]
-        try:
-            chunk_size = int(size_str, 16)
-        except ValueError:
-            break
+        chunk_size = _parse_chunk_size(header)
         if chunk_size == 0:
             break
         data_start = header_end + 2
         data_end = data_start + chunk_size
         if data_end > len(body):
-            break
+            raise ValueError(
+                f"Truncated chunk: expected {chunk_size} bytes, "
+                f"only {len(body) - data_start} available"
+            )
         result.extend(body[data_start:data_end])
         pos = data_end + 2
     return bytes(result)
@@ -59,23 +81,32 @@ async def decode_aws_chunked_stream(
 
     Yields:
         Decoded data chunks
+
+    Raises:
+        ValueError: If buffer exceeds safety limits or encoding is malformed.
     """
     buffer = bytearray()
 
     async for raw_chunk in request.stream():
         buffer.extend(raw_chunk)
 
+        if len(buffer) > _MAX_BUFFER_SIZE:
+            raise ValueError(
+                f"Chunked decode buffer ({len(buffer)} bytes) exceeds "
+                f"maximum ({_MAX_BUFFER_SIZE} bytes)"
+            )
+
         while True:
             header_end = buffer.find(b"\r\n")
             if header_end == -1:
+                if len(buffer) > _MAX_CHUNK_HEADER_SIZE:
+                    raise ValueError(
+                        f"Chunk header exceeds {_MAX_CHUNK_HEADER_SIZE} bytes"
+                    )
                break
 
             header = buffer[:header_end]
-            size_str = header.split(b";")[0]
-            try:
-                chunk_size = int(size_str, 16)
-            except ValueError:
-                break
+            chunk_size = _parse_chunk_size(header)
 
             if chunk_size == 0:
                 return
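
A usage sketch of the hardened buffered decoder, importing from the module path shown above; the chunk-signature values are dummies:

from s3proxy.streaming.chunked import decode_aws_chunked

# aws-chunked framing: hex payload size (optionally ";chunk-signature=..."),
# CRLF, payload, CRLF, terminated by a zero-size chunk.
body = (
    b"5;chunk-signature=aaaa\r\n"
    b"hello\r\n"
    b"6;chunk-signature=bbbb\r\n"
    b" world\r\n"
    b"0;chunk-signature=cccc\r\n"
    b"\r\n"
)
assert decode_aws_chunked(body) == b"hello world"

# Truncated input now raises instead of silently returning partial data.
try:
    decode_aws_chunked(b"5;chunk-signature=aaaa\r\nhel")
except ValueError as exc:
    print(exc)  # Truncated chunk: expected 5 bytes, only 3 available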

0 commit comments
