Skip to content

Commit 52e1488

Browse files
fix: harden error handling, concurrency, and XML safety
- Escape XML in error responses to prevent injection via error messages
- Add retry limit (MAX_WATCH_RETRIES=5) to Redis WATCH/MULTI/EXEC loops to prevent unbounded recursion under high contention
- Refactor concurrency limiter from module-level globals to a ConcurrencyLimiter class with proper encapsulation; module-level functions delegate to a default instance for backward compatibility
- Clean up metrics import (remove noqa suppression, use real name)
1 parent 97eb9d4 commit 52e1488

3 files changed

Lines changed: 119 additions & 85 deletions

File tree

s3proxy/app.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
import uuid
99
from collections.abc import AsyncIterator
1010
from contextlib import asynccontextmanager
11+
from xml.sax.saxutils import escape as xml_escape
1112

1213
import structlog
1314
from fastapi import FastAPI, HTTPException, Request, Response
1415
from fastapi.responses import PlainTextResponse
1516
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
1617
from structlog.stdlib import BoundLogger
1718

18-
from . import metrics as _ # noqa: F401 - Import to register metrics
19+
from . import metrics # Ensure Prometheus collectors are registered at import time
1920
from .config import Settings
2021
from .errors import S3Error, get_s3_error_code
2122
from .handlers import S3ProxyHandler
@@ -130,8 +131,8 @@ async def s3_exception_handler(request: Request, exc: HTTPException):
130131

131132
error_xml = f"""<?xml version="1.0" encoding="UTF-8"?>
132133
<Error>
133-
<Code>{error_code}</Code>
134-
<Message>{message}</Message>
134+
<Code>{xml_escape(error_code)}</Code>
135+
<Message>{xml_escape(str(message))}</Message>
135136
<RequestId>{request_id}</RequestId>
136137
</Error>"""
137138
return Response(
@@ -161,7 +162,7 @@ async def metrics():
161162
"/{path:path}",
162163
methods=["GET", "PUT", "POST", "DELETE", "HEAD"],
163164
)
164-
async def proxy(request: Request, path: str): # noqa: ARG001
165+
async def proxy(request: Request, path: str): # noqa: ARG001 - required by FastAPI for {path:path}
165166
return await handle_proxy_request(
166167
request, request.app.state.handler, request.app.state.verifier
167168
)

s3proxy/concurrency.py

Lines changed: 99 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,6 @@
2020
MIN_RESERVATION = 64 * 1024 # 64KB minimum per request
2121
MAX_BUFFER_SIZE = 8 * 1024 * 1024 # 8MB streaming buffer size
2222

23-
# Module-level state
24-
_limit_mb = int(os.environ.get("S3PROXY_MEMORY_LIMIT_MB", "64"))
25-
_limit_bytes = _limit_mb * 1024 * 1024
26-
_active_bytes = 0
27-
_lock: asyncio.Lock | None = None
28-
29-
# Initialize memory limit metric
30-
MEMORY_LIMIT_BYTES.set(_limit_bytes)
31-
3223

3324
def _create_malloc_release() -> Callable[[], int] | None:
3425
"""Create platform-specific function to release memory back to OS.
@@ -50,19 +41,89 @@ def _create_malloc_release() -> Callable[[], int] | None:
5041
_malloc_release = _create_malloc_release()
5142

5243

53-
async def _get_lock() -> asyncio.Lock:
54-
global _lock
55-
if _lock is None:
56-
_lock = asyncio.Lock()
57-
return _lock
58-
59-
60-
def get_memory_limit() -> int:
61-
return _limit_bytes
44+
class ConcurrencyLimiter:
45+
"""Memory-based concurrency limiter.
6246
47+
Tracks reserved memory across concurrent requests and rejects new requests
48+
when the configured limit would be exceeded.
49+
"""
6350

64-
def get_active_memory() -> int:
65-
return _active_bytes
51+
def __init__(self, limit_mb: int = 64) -> None:
52+
self._limit_mb = limit_mb
53+
self._limit_bytes = limit_mb * 1024 * 1024
54+
self._active_bytes = 0
55+
self._lock = asyncio.Lock()
56+
MEMORY_LIMIT_BYTES.set(self._limit_bytes)
57+
58+
@property
59+
def limit_bytes(self) -> int:
60+
return self._limit_bytes
61+
62+
@property
63+
def active_bytes(self) -> int:
64+
return self._active_bytes
65+
66+
@active_bytes.setter
67+
def active_bytes(self, value: int) -> None:
68+
"""Set active memory (testing only)."""
69+
self._active_bytes = value
70+
71+
def set_memory_limit(self, limit_mb: int) -> None:
72+
"""Update the memory limit."""
73+
self._limit_mb = limit_mb
74+
self._limit_bytes = limit_mb * 1024 * 1024
75+
MEMORY_LIMIT_BYTES.set(self._limit_bytes)
76+
77+
async def try_acquire(self, bytes_needed: int) -> int:
78+
"""Reserve memory. Returns bytes reserved. Raises S3Error.slow_down if exhausted."""
79+
if self._limit_bytes <= 0:
80+
return 0
81+
82+
to_reserve = max(MIN_RESERVATION, min(bytes_needed, self._limit_bytes))
83+
84+
async with self._lock:
85+
if self._active_bytes + to_reserve > self._limit_bytes:
86+
active_mb = self._active_bytes / 1024 / 1024
87+
request_mb = to_reserve / 1024 / 1024
88+
limit_mb = self._limit_bytes / 1024 / 1024
89+
logger.warning("MEMORY_REJECTED", active_mb=round(active_mb, 2),
90+
requested_mb=round(request_mb, 2), limit_mb=round(limit_mb, 2))
91+
MEMORY_REJECTIONS.inc()
92+
raise S3Error.slow_down(
93+
f"Memory limit: {active_mb:.0f}MB + {request_mb:.0f}MB > {limit_mb:.0f}MB"
94+
)
95+
self._active_bytes += to_reserve
96+
MEMORY_RESERVED_BYTES.set(self._active_bytes)
97+
return to_reserve
98+
99+
async def release(self, bytes_reserved: int) -> None:
100+
"""Release reserved memory and trigger OS memory release."""
101+
if self._limit_bytes <= 0 or bytes_reserved <= 0:
102+
return
103+
104+
async with self._lock:
105+
self._active_bytes = max(0, self._active_bytes - bytes_reserved)
106+
MEMORY_RESERVED_BYTES.set(self._active_bytes)
107+
108+
# Run garbage collection and release memory to OS
109+
gc.collect(0)
110+
gc.collect(1)
111+
gc.collect(2)
112+
113+
if _malloc_release:
114+
try:
115+
_malloc_release()
116+
except OSError:
117+
pass
118+
119+
# Yield to allow OS memory reclaim
120+
await asyncio.sleep(0)
121+
122+
123+
# Default instance used by module-level functions
124+
_default = ConcurrencyLimiter(
125+
limit_mb=int(os.environ.get("S3PROXY_MEMORY_LIMIT_MB", "64"))
126+
)
66127

67128

68129
def estimate_memory_footprint(method: str, content_length: int) -> int:
@@ -78,75 +139,38 @@ def estimate_memory_footprint(method: str, content_length: int) -> int:
78139
return MAX_BUFFER_SIZE
79140

80141

81-
async def try_acquire_memory(bytes_needed: int) -> int:
82-
"""Reserve memory. Returns bytes reserved. Raises S3Error.slow_down if exhausted."""
83-
global _active_bytes
142+
# Module-level convenience functions delegating to the default instance
84143

85-
if _limit_bytes <= 0:
86-
return 0
87-
88-
to_reserve = max(MIN_RESERVATION, min(bytes_needed, _limit_bytes))
89-
90-
lock = await _get_lock()
91-
async with lock:
92-
if _active_bytes + to_reserve > _limit_bytes:
93-
active_mb = _active_bytes / 1024 / 1024
94-
request_mb = to_reserve / 1024 / 1024
95-
limit_mb = _limit_bytes / 1024 / 1024
96-
logger.warning("MEMORY_REJECTED", active_mb=round(active_mb, 2),
97-
requested_mb=round(request_mb, 2), limit_mb=round(limit_mb, 2))
98-
MEMORY_REJECTIONS.inc()
99-
raise S3Error.slow_down(
100-
f"Memory limit: {active_mb:.0f}MB + {request_mb:.0f}MB > {limit_mb:.0f}MB"
101-
)
102-
_active_bytes += to_reserve
103-
MEMORY_RESERVED_BYTES.set(_active_bytes)
104-
return to_reserve
105144

145+
def get_memory_limit() -> int:
146+
return _default.limit_bytes
106147

107-
async def release_memory(bytes_reserved: int) -> None:
108-
"""Release reserved memory and trigger OS memory release."""
109-
global _active_bytes
110148

111-
if _limit_bytes <= 0 or bytes_reserved <= 0:
112-
return
149+
def get_active_memory() -> int:
150+
return _default.active_bytes
113151

114-
lock = await _get_lock()
115-
async with lock:
116-
_active_bytes = max(0, _active_bytes - bytes_reserved)
117-
MEMORY_RESERVED_BYTES.set(_active_bytes)
118152

119-
# Run garbage collection and release memory to OS
120-
gc.collect(0)
121-
gc.collect(1)
122-
gc.collect(2)
153+
async def try_acquire_memory(bytes_needed: int) -> int:
154+
return await _default.try_acquire(bytes_needed)
123155

124-
if _malloc_release:
125-
try:
126-
_malloc_release()
127-
except OSError:
128-
pass
129156

130-
# Yield to allow OS memory reclaim
131-
await asyncio.sleep(0)
157+
async def release_memory(bytes_reserved: int) -> None:
158+
await _default.release(bytes_reserved)
132159

133160

134161
def reset_state() -> None:
135-
"""Reset state (testing only)."""
136-
global _active_bytes, _lock
137-
_active_bytes = 0
138-
_lock = None
162+
"""Reset default instance state (testing only)."""
163+
global _default
164+
_default = ConcurrencyLimiter(limit_mb=_default._limit_mb)
165+
# Reset reserved bytes metric to 0 for clean test state
166+
MEMORY_RESERVED_BYTES.set(0)
139167

140168

141169
def set_memory_limit(limit_mb: int) -> None:
142-
"""Set memory limit (testing only)."""
143-
global _limit_mb, _limit_bytes
144-
_limit_mb = limit_mb
145-
_limit_bytes = limit_mb * 1024 * 1024
146-
MEMORY_LIMIT_BYTES.set(_limit_bytes)
170+
"""Set memory limit on default instance (testing only)."""
171+
_default.set_memory_limit(limit_mb)
147172

148173

149174
def set_active_memory(bytes_val: int) -> None:
150-
"""Set active memory (testing only)."""
151-
global _active_bytes
152-
_active_bytes = bytes_val
175+
"""Set active memory on default instance (testing only)."""
176+
_default.active_bytes = bytes_val

s3proxy/state/storage.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
# Type alias for updater function: takes bytes, returns bytes
2222
Updater = Callable[[bytes], bytes]
2323

24+
# Maximum retries for Redis optimistic locking (WATCH/MULTI/EXEC)
25+
MAX_WATCH_RETRIES = 5
26+
2427

2528
class StateStore(ABC):
2629
"""Abstract interface for state storage backends."""
@@ -118,7 +121,7 @@ async def set(self, key: str, value: bytes, ttl_seconds: int) -> None:
118121
async def delete(self, key: str) -> None:
119122
await self._client.delete(self._key(key))
120123

121-
async def get_and_delete(self, key: str) -> bytes | None:
124+
async def get_and_delete(self, key: str, _retries: int = 0) -> bytes | None:
122125
"""Atomically get and delete using Redis transaction."""
123126
import redis.asyncio as redis
124127

@@ -137,11 +140,14 @@ async def get_and_delete(self, key: str) -> bytes | None:
137140
return data
138141

139142
except redis.WatchError:
140-
# Retry on conflict
141-
return await self.get_and_delete(key)
143+
if _retries >= MAX_WATCH_RETRIES:
144+
logger.error("REDIS_WATCH_RETRIES_EXHAUSTED", key=key, operation="get_and_delete")
145+
raise
146+
logger.debug("REDIS_WATCH_RETRY", key=key, attempt=_retries + 1)
147+
return await self.get_and_delete(key, _retries=_retries + 1)
142148

143149
async def update(
144-
self, key: str, updater: Updater, ttl_seconds: int
150+
self, key: str, updater: Updater, ttl_seconds: int, _retries: int = 0
145151
) -> bytes | None:
146152
"""Atomically update using Redis WATCH/MULTI/EXEC."""
147153
import redis.asyncio as redis
@@ -163,5 +169,8 @@ async def update(
163169
return new_data
164170

165171
except redis.WatchError:
166-
logger.debug("REDIS_WATCH_RETRY", key=key)
167-
return await self.update(key, updater, ttl_seconds)
172+
if _retries >= MAX_WATCH_RETRIES:
173+
logger.error("REDIS_WATCH_RETRIES_EXHAUSTED", key=key, operation="update")
174+
raise
175+
logger.debug("REDIS_WATCH_RETRY", key=key, attempt=_retries + 1)
176+
return await self.update(key, updater, ttl_seconds, _retries=_retries + 1)

0 commit comments

Comments (0)