Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 163 additions & 12 deletions src/shade/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,30 @@
from __future__ import annotations

import json
import logging
import math
import random
import time
import urllib.error
import urllib.parse
import urllib.request
from typing import Any, Dict, Optional, Tuple

from .errors import HTTPError, RateLimitError
from .errors import (
AuthenticationError,
HTTPError,
InvalidRequestError,
NetworkError,
NotFoundError,
RateLimitError,
)

logger = logging.getLogger(__name__)

try: # pragma: no cover - optional dependency
import httpx
except ImportError: # pragma: no cover - optional dependency
httpx = None

# ---------------------------------------------------------------------------
# Constants
Expand Down Expand Up @@ -62,6 +78,72 @@ def _backoff_seconds(attempt: int) -> float:
return min(_BASE_BACKOFF * math.pow(2, attempt), _MAX_BACKOFF)


def _retry_delay(attempt: int, base_delay: float) -> float:
"""Return a capped exponential delay with randomized jitter."""
return min(base_delay * (2**attempt) + random.uniform(0, 0.5), _MAX_BACKOFF)


def _is_retryable_transport_error(exc: Exception) -> bool:
"""Return True for transient network failures that should be retried."""
if httpx is not None and isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)):
return True

try:
import aiohttp
except ImportError:
aiohttp = None

if aiohttp is not None and isinstance(
exc,
(
aiohttp.ClientConnectionError,
aiohttp.ClientConnectorError,
aiohttp.ClientOSError,
aiohttp.ServerDisconnectedError,
),
):
return True

if isinstance(exc, (ConnectionResetError, TimeoutError, urllib.error.URLError)):
return True
return False
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _is_retryable_status(status: int) -> bool:
return status in {502, 503, 504}


def _retry_with_backoff(fn, max_retries: int, base_delay: float):
"""Execute *fn* and retry transient failures with exponential back-off."""
for attempt in range(max_retries + 1):
try:
return fn()
except Exception as exc:
if attempt >= max_retries or not _is_retryable_error(exc):
raise
delay = _retry_delay(attempt, base_delay)
logger.debug(
"Retrying request after transient failure (attempt %s/%s) in %.3fs",
attempt + 1,
max_retries + 1,
delay,
)
time.sleep(delay)


def _is_retryable_error(exc: Exception) -> bool:
if _is_retryable_transport_error(exc):
return True

if httpx is not None and isinstance(exc, httpx.HTTPStatusError):
return _is_retryable_status(exc.response.status_code)

if isinstance(exc, HTTPError):
return _is_retryable_status(exc.status_code or 0)

return False


def _raise_for_status(
status: int,
headers: Any,
Expand All @@ -81,6 +163,14 @@ def _raise_for_status(
------
RateLimitError
If HTTP 429 and retries are exhausted (or auto-retry is off).
InvalidRequestError
For HTTP 400 responses.
AuthenticationError
For HTTP 401/403 responses.
NotFoundError
For HTTP 404 responses.
NetworkError
For transient 502/503/504 responses after retries are exhausted.
HTTPError
For any other non-2xx status.
"""
Expand All @@ -100,6 +190,32 @@ def _raise_for_status(
msg = f"Rate limit exceeded. {detail}".strip()
raise RateLimitError(msg, retry_after=retry_after)

if status == 400:
raise InvalidRequestError("Invalid request", status_code=status)

if status in {401, 403}:
raise AuthenticationError("Authentication failed", status_code=status)

if status == 404:
response_body = body.decode("utf-8", errors="replace")
raise NotFoundError(
"Resource not found",
status_code=status,
response_body=response_body,
)

if status in {502, 503, 504}:
if attempt < max_retries:
wait = _retry_delay(attempt, _BASE_BACKOFF)
logger.debug(
"Retrying request after server error (attempt %s/%s) in %.3fs",
attempt + 1,
max_retries + 1,
wait,
)
return wait
raise NetworkError(f"Request failed with transient server error: {status}", status_code=status)

try:
detail = json.loads(body).get("error", {}).get("message", "")
except Exception:
Expand Down Expand Up @@ -176,12 +292,29 @@ def request(
attempt = 0
while True:
req = self._build_request(method, path, payload)
status, headers, body = self._execute(req)
try:
status, headers, body = self._execute(req)
except Exception as exc:
if _is_retryable_transport_error(exc):
if attempt >= self.max_retries:
raise NetworkError(
"Request failed after exhausting retries",
status_code=None,
) from exc
delay = _retry_delay(attempt, _BASE_BACKOFF)
logger.debug(
"Retrying request after transient failure (attempt %s/%s) in %.3fs",
attempt + 1,
self.max_retries + 1,
delay,
)
time.sleep(delay)
attempt += 1
continue
raise
wait = _raise_for_status(status, headers, body, attempt, self.max_retries)
if wait is None:
# success
return json.loads(body) if body else {}
# 429 — sleep and retry
time.sleep(wait)
attempt += 1

Expand Down Expand Up @@ -269,18 +402,36 @@ async def request(
) as session:
while True:
url = f"{url_base}/{path.lstrip('/')}"
resp = await session.request(
method.upper(),
url,
json=payload,
headers=headers,
)
body = await resp.read()
try:
resp = await session.request(
method.upper(),
url,
json=payload,
headers=headers,
)
body = await resp.read()
except Exception as exc:
if _is_retryable_transport_error(exc):
if attempt >= self.max_retries:
raise NetworkError(
"Request failed after exhausting retries",
status_code=None,
) from exc
delay = _retry_delay(attempt, _BASE_BACKOFF)
logger.debug(
"Retrying request after transient failure (attempt %s/%s) in %.3fs",
attempt + 1,
self.max_retries + 1,
delay,
)
await asyncio.sleep(delay)
attempt += 1
continue
raise
wait = _raise_for_status(
resp.status, resp.headers, body, attempt, self.max_retries
)
if wait is None:
return json.loads(body) if body else {}
# 429 — non-blocking sleep
await asyncio.sleep(wait)
attempt += 1
63 changes: 62 additions & 1 deletion tests/test_rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest

from shade import RateLimitError
from shade.errors import HTTPError
from shade.errors import HTTPError, InvalidRequestError, NetworkError
from shade.http import (
DEFAULT_MAX_RETRIES,
AsyncHTTPClient,
Expand Down Expand Up @@ -267,6 +267,67 @@ def fake_execute(req):

assert exc_info.value.retry_after is None

def test_retries_on_transient_5xx_then_succeeds(self):
client = self._client(max_retries=2)
responses = [
(503, {}, b"{}"),
(503, {}, b"{}"),
(200, {}, _fake_200_body()),
]
idx = 0

def fake_execute(req):
nonlocal idx
status, headers, body = responses[idx]
idx += 1
return status, headers, body

sleep_calls: List[float] = []
with patch.object(client, "_execute", side_effect=fake_execute), \
patch("time.sleep", side_effect=lambda s: sleep_calls.append(s)), \
patch("shade.http.random.uniform", side_effect=[0.0, 0.0]):
result = client.request("GET", "/payments")

assert result == {"id": "pay_123", "status": "ok"}
assert sleep_calls == [1.0, 2.0]

def test_400_raises_invalid_request_error_immediately(self):
client = self._client(max_retries=3)

def fake_execute(req):
return 400, {}, b'{"error": {"message": "bad request"}}'

with patch.object(client, "_execute", side_effect=fake_execute), \
patch("time.sleep") as mock_sleep:
with pytest.raises(InvalidRequestError) as exc_info:
client.request("GET", "/payments")

mock_sleep.assert_not_called()
assert exc_info.value.status_code == 400

def test_retries_exhausted_raise_network_error(self):
client = self._client(max_retries=1)
responses = [
(503, {}, b"{}"),
(503, {}, b"{}"),
(503, {}, b"{}"),
]
idx = 0

def fake_execute(req):
nonlocal idx
status, headers, body = responses[idx]
idx += 1
return status, headers, body

with patch.object(client, "_execute", side_effect=fake_execute), \
patch("time.sleep") as mock_sleep:
with pytest.raises(NetworkError) as exc_info:
client.request("GET", "/payments")

assert exc_info.value.status_code == 503
assert mock_sleep.call_count == 1


# ---------------------------------------------------------------------------
# AsyncHTTPClient
Expand Down
Loading