From 9c6d49d68bfd2c4632555f69b4d9929ed29bdeb2 Mon Sep 17 00:00:00 2001 From: G-ELM Date: Fri, 26 Jun 2026 17:13:54 +0100 Subject: [PATCH 1/2] feat: add retry backoff for transient failures --- src/shade/http.py | 153 ++++++++++++++++++++++++++++++++++++--- tests/test_rate_limit.py | 63 +++++++++++++++- 2 files changed, 203 insertions(+), 13 deletions(-) diff --git a/src/shade/http.py b/src/shade/http.py index f3284c9..1a7e7f2 100644 --- a/src/shade/http.py +++ b/src/shade/http.py @@ -10,14 +10,30 @@ from __future__ import annotations import json +import logging import math +import random import time import urllib.error import urllib.parse import urllib.request from typing import Any, Dict, Optional, Tuple -from .errors import HTTPError, RateLimitError +from .errors import ( + AuthenticationError, + HTTPError, + InvalidRequestError, + NetworkError, + NotFoundError, + RateLimitError, +) + +logger = logging.getLogger(__name__) + +try: # pragma: no cover - optional dependency + import httpx +except ImportError: # pragma: no cover - optional dependency + httpx = None # --------------------------------------------------------------------------- # Constants @@ -62,6 +78,55 @@ def _backoff_seconds(attempt: int) -> float: return min(_BASE_BACKOFF * math.pow(2, attempt), _MAX_BACKOFF) +def _retry_delay(attempt: int, base_delay: float) -> float: + """Return a capped exponential delay with randomized jitter.""" + return min(base_delay * (2**attempt) + random.uniform(0, 0.5), _MAX_BACKOFF) + + +def _is_retryable_transport_error(exc: Exception) -> bool: + """Return True for transient network failures that should be retried.""" + if httpx is not None and isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)): + return True + if isinstance(exc, (ConnectionResetError, TimeoutError, urllib.error.URLError)): + return True + return False + + +def _is_retryable_status(status: int) -> bool: + return status in {502, 503, 504} + + +def _retry_with_backoff(fn, max_retries: int, base_delay: float): + """Execute *fn* and retry transient failures with exponential back-off.""" + for attempt in range(max_retries + 1): + try: + return fn() + except Exception as exc: + if attempt >= max_retries or not _is_retryable_error(exc): + raise + delay = _retry_delay(attempt, base_delay) + logger.debug( + "Retrying request after transient failure (attempt %s/%s) in %.3fs", + attempt + 1, + max_retries + 1, + delay, + ) + time.sleep(delay) + + +def _is_retryable_error(exc: Exception) -> bool: + if _is_retryable_transport_error(exc): + return True + + if httpx is not None and isinstance(exc, httpx.HTTPStatusError): + return _is_retryable_status(exc.response.status_code) + + if isinstance(exc, HTTPError): + return _is_retryable_status(exc.status_code or 0) + + return False + + def _raise_for_status( status: int, headers: Any, @@ -81,6 +146,14 @@ def _raise_for_status( ------ RateLimitError If HTTP 429 and retries are exhausted (or auto-retry is off). + InvalidRequestError + For HTTP 400 responses. + AuthenticationError + For HTTP 401/403 responses. + NotFoundError + For HTTP 404 responses. + NetworkError + For transient 502/503/504 responses after retries are exhausted. HTTPError For any other non-2xx status. """ @@ -100,6 +173,27 @@ def _raise_for_status( msg = f"Rate limit exceeded. {detail}".strip() raise RateLimitError(msg, retry_after=retry_after) + if status == 400: + raise InvalidRequestError("Invalid request", status_code=status) + + if status in {401, 403}: + raise AuthenticationError("Authentication failed", status_code=status) + + if status == 404: + raise NotFoundError("Resource not found", status_code=status) + + if status in {502, 503, 504}: + if attempt < max_retries: + wait = _retry_delay(attempt, _BASE_BACKOFF) + logger.debug( + "Retrying request after server error (attempt %s/%s) in %.3fs", + attempt + 1, + max_retries + 1, + wait, + ) + return wait + raise NetworkError(f"Request failed with transient server error: {status}", status_code=status) + try: detail = json.loads(body).get("error", {}).get("message", "") except Exception: @@ -176,12 +270,29 @@ def request( attempt = 0 while True: req = self._build_request(method, path, payload) - status, headers, body = self._execute(req) + try: + status, headers, body = self._execute(req) + except Exception as exc: + if _is_retryable_transport_error(exc): + if attempt >= self.max_retries: + raise NetworkError( + "Request failed after exhausting retries", + status_code=None, + ) from exc + delay = _retry_delay(attempt, _BASE_BACKOFF) + logger.debug( + "Retrying request after transient failure (attempt %s/%s) in %.3fs", + attempt + 1, + self.max_retries + 1, + delay, + ) + time.sleep(delay) + attempt += 1 + continue + raise wait = _raise_for_status(status, headers, body, attempt, self.max_retries) if wait is None: - # success return json.loads(body) if body else {} - # 429 — sleep and retry time.sleep(wait) attempt += 1 @@ -269,18 +380,36 @@ async def request( ) as session: while True: url = f"{url_base}/{path.lstrip('/')}" - resp = await session.request( - method.upper(), - url, - json=payload, - headers=headers, - ) - body = await resp.read() + try: + resp = await session.request( + method.upper(), + url, + json=payload, + headers=headers, + ) + body = await resp.read() + except Exception as exc: + if _is_retryable_transport_error(exc): + if attempt >= self.max_retries: + raise NetworkError( + "Request failed after exhausting retries", + status_code=None, + ) from exc + delay = _retry_delay(attempt, _BASE_BACKOFF) + logger.debug( + "Retrying request after transient failure (attempt %s/%s) in %.3fs", + attempt + 1, + self.max_retries + 1, + delay, + ) + await asyncio.sleep(delay) + attempt += 1 + continue + raise wait = _raise_for_status( resp.status, resp.headers, body, attempt, self.max_retries ) if wait is None: return json.loads(body) if body else {} - # 429 — non-blocking sleep await asyncio.sleep(wait) attempt += 1 \ No newline at end of file diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py index 7d1174f..3349590 100644 --- a/tests/test_rate_limit.py +++ b/tests/test_rate_limit.py @@ -24,7 +24,7 @@ import pytest from shade import RateLimitError -from shade.errors import HTTPError +from shade.errors import HTTPError, InvalidRequestError, NetworkError from shade.http import ( DEFAULT_MAX_RETRIES, AsyncHTTPClient, @@ -267,6 +267,67 @@ def fake_execute(req): assert exc_info.value.retry_after is None + def test_retries_on_transient_5xx_then_succeeds(self): + client = self._client(max_retries=2) + responses = [ + (503, {}, b"{}"), + (503, {}, b"{}"), + (200, {}, _fake_200_body()), + ] + idx = 0 + + def fake_execute(req): + nonlocal idx + status, headers, body = responses[idx] + idx += 1 + return status, headers, body + + sleep_calls: List[float] = [] + with patch.object(client, "_execute", side_effect=fake_execute), \ + patch("time.sleep", side_effect=lambda s: sleep_calls.append(s)), \ + patch("shade.http.random.uniform", side_effect=[0.0, 0.0]): + result = client.request("GET", "/payments") + + assert result == {"id": "pay_123", "status": "ok"} + assert sleep_calls == [1.0, 2.0] + + def test_400_raises_invalid_request_error_immediately(self): + client = self._client(max_retries=3) + + def fake_execute(req): + return 400, {}, b'{"error": {"message": "bad request"}}' + + with patch.object(client, "_execute", side_effect=fake_execute), \ + patch("time.sleep") as mock_sleep: + with pytest.raises(InvalidRequestError) as exc_info: + client.request("GET", "/payments") + + mock_sleep.assert_not_called() + assert exc_info.value.status_code == 400 + + def test_retries_exhausted_raise_network_error(self): + client = self._client(max_retries=1) + responses = [ + (503, {}, b"{}"), + (503, {}, b"{}"), + (503, {}, b"{}"), + ] + idx = 0 + + def fake_execute(req): + nonlocal idx + status, headers, body = responses[idx] + idx += 1 + return status, headers, body + + with patch.object(client, "_execute", side_effect=fake_execute), \ + patch("time.sleep") as mock_sleep: + with pytest.raises(NetworkError) as exc_info: + client.request("GET", "/payments") + + assert exc_info.value.status_code == 503 + assert mock_sleep.call_count == 1 + # --------------------------------------------------------------------------- # AsyncHTTPClient From ef4062aff62e772340042179b42c982381a34888 Mon Sep 17 00:00:00 2001 From: G-ELM Date: Fri, 26 Jun 2026 18:17:51 +0100 Subject: [PATCH 2/2] feat: implement coderabbit review --- src/shade/http.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/shade/http.py b/src/shade/http.py index 1a7e7f2..b05ab26 100644 --- a/src/shade/http.py +++ b/src/shade/http.py @@ -87,6 +87,23 @@ def _is_retryable_transport_error(exc: Exception) -> bool: """Return True for transient network failures that should be retried.""" if httpx is not None and isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)): return True + + try: + import aiohttp + except ImportError: + aiohttp = None + + if aiohttp is not None and isinstance( + exc, + ( + aiohttp.ClientConnectionError, + aiohttp.ClientConnectorError, + aiohttp.ClientOSError, + aiohttp.ServerDisconnectedError, + ), + ): + return True + if isinstance(exc, (ConnectionResetError, TimeoutError, urllib.error.URLError)): return True return False @@ -180,7 +197,12 @@ def _raise_for_status( raise AuthenticationError("Authentication failed", status_code=status) if status == 404: - raise NotFoundError("Resource not found", status_code=status) + response_body = body.decode("utf-8", errors="replace") + raise NotFoundError( + "Resource not found", + status_code=status, + response_body=response_body, + ) if status in {502, 503, 504}: if attempt < max_retries: