Skip to content

Commit 5815d12

Browse files
committed
feat(async-oauth-client): add retry mechanism to AsyncOauthClient
1 parent 69eecb5 commit 5815d12

1 file changed

Lines changed: 49 additions & 6 deletions

File tree

openapi_python_sdk/async_oauth_client.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import base64
3+
import random
24
from typing import Any, Dict, List
35

46
import httpx
@@ -12,8 +14,23 @@ class AsyncOauthClient:
1214
Suitable for use with FastAPI, aiohttp, etc.
1315
"""
1416

15-
def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None, timeout: float = 30.0):
17+
def __init__(
18+
self,
19+
username: str,
20+
apikey: str,
21+
test: bool = False,
22+
client: Any = None,
23+
timeout: float = 30.0,
24+
max_retries: int = 0,
25+
backoff_factor: float = 1.0,
26+
retry_on_status: List[int] = None,
27+
):
1628
self.client = client if client is not None else httpx.AsyncClient(timeout=timeout)
29+
self.max_retries = max_retries
30+
self.backoff_factor = backoff_factor
31+
self.retry_on_status = (
32+
retry_on_status if retry_on_status is not None else [429, 502, 503, 504]
33+
)
1734
self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL
1835
self.auth_header: str = (
1936
"Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode()
@@ -35,35 +52,61 @@ async def aclose(self):
3552
"""Manually close the underlying HTTP client (async)."""
3653
await self.client.aclose()
3754

55+
async def _request_with_retry(self, request_fn, *args, **kwargs) -> httpx.Response:
56+
attempts = 0
57+
while True:
58+
try:
59+
resp = await request_fn(*args, **kwargs)
60+
if resp.status_code in self.retry_on_status and attempts < self.max_retries:
61+
attempts += 1
62+
sleep_time = self.backoff_factor * (2 ** attempts) + random.uniform(0, 0.5)
63+
if resp.status_code == 429:
64+
retry_after = resp.headers.get("Retry-After")
65+
if retry_after:
66+
try:
67+
sleep_time = float(retry_after)
68+
except ValueError:
69+
pass
70+
await asyncio.sleep(sleep_time)
71+
continue
72+
return resp
73+
except httpx.RequestError as exc:
74+
if attempts < self.max_retries:
75+
attempts += 1
76+
sleep_time = self.backoff_factor * (2 ** attempts) + random.uniform(0, 0.5)
77+
await asyncio.sleep(sleep_time)
78+
continue
79+
raise exc
80+
3881
async def get_scopes(self, limit: bool = False) -> Dict[str, Any]:
3982
"""Retrieve available scopes for the current user (async)."""
4083
params = {"limit": int(limit)}
4184
url = f"{self.url}/scopes"
42-
resp = await self.client.get(url=url, headers=self.headers, params=params)
85+
resp = await self._request_with_retry(self.client.get, url=url, headers=self.headers, params=params)
4386
return resp.json()
4487

4588
async def create_token(self, scopes: List[str] = [], ttl: int = 0) -> Dict[str, Any]:
4689
"""Create a new bearer token with specified scopes and TTL (async)."""
4790
payload = {"scopes": scopes, "ttl": ttl}
4891
url = f"{self.url}/token"
49-
resp = await self.client.post(url=url, headers=self.headers, json=payload)
92+
resp = await self._request_with_retry(self.client.post, url=url, headers=self.headers, json=payload)
5093
return resp.json()
5194

5295
async def get_token(self, scope: str = None) -> Dict[str, Any]:
5396
"""Retrieve an existing token, optionally filtered by scope (async)."""
5497
params = {"scope": scope or ""}
5598
url = f"{self.url}/token"
56-
resp = await self.client.get(url=url, headers=self.headers, params=params)
99+
resp = await self._request_with_retry(self.client.get, url=url, headers=self.headers, params=params)
57100
return resp.json()
58101

59102
async def delete_token(self, id: str) -> Dict[str, Any]:
60103
"""Revoke/Delete a specific token by ID (async)."""
61104
url = f"{self.url}/token/{id}"
62-
resp = await self.client.delete(url=url, headers=self.headers)
105+
resp = await self._request_with_retry(self.client.delete, url=url, headers=self.headers)
63106
return resp.json()
64107

65108
async def get_counters(self, period: str, date: str) -> Dict[str, Any]:
66109
"""Retrieve usage counters for a specific period and date (async)."""
67110
url = f"{self.url}/counters/{period}/{date}"
68-
resp = await self.client.get(url=url, headers=self.headers)
111+
resp = await self._request_with_retry(self.client.get, url=url, headers=self.headers)
69112
return resp.json()

0 commit comments

Comments
 (0)