Skip to content
Open
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,26 @@ resp = client.request(
)
```

### Customizing the Transport Layer

If you need to configure custom retry logic, proxies, or use a different HTTP client (such as passing a `requests.Session` with a custom urllib3 `Retry`), you can inject it directly using the `client` parameter on any SDK class:

```python
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from openapi_python_sdk import Client
import requests

retry = Retry(total=3)
adapter = HTTPAdapter(max_retries=retry)

session = requests.Session()
session.mount("https://", adapter)

# Pass the custom session to the Client explicitly
client = Client("token", client=session)
```

## Async Usage

The SDK provides `AsyncClient` and `AsyncOauthClient` for use with asynchronous frameworks like FastAPI or `aiohttp`.
Expand Down
11 changes: 9 additions & 2 deletions openapi_python_sdk/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class AsyncClient:
Suitable for use with FastAPI, aiohttp, etc.
"""

def __init__(self, token: str):
self.client = httpx.AsyncClient()
def __init__(self, token: str, client: Any = None):
self.client = client if client is not None else httpx.AsyncClient()
self.auth_header: str = f"Bearer {token}"
self.headers: Dict[str, str] = {
"Authorization": self.auth_header,
Expand Down Expand Up @@ -43,6 +43,13 @@ async def request(
payload = payload or {}
params = params or {}
url = url or ""

if params:
import urllib.parse
query_string = urllib.parse.urlencode(params, doseq=True)
url = f"{url}&{query_string}" if "?" in url else f"{url}?{query_string}"
params = None

resp = await self.client.request(
method=method,
url=url,
Expand Down
4 changes: 2 additions & 2 deletions openapi_python_sdk/async_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class AsyncOauthClient:
Suitable for use with FastAPI, aiohttp, etc.
"""

def __init__(self, username: str, apikey: str, test: bool = False):
self.client = httpx.AsyncClient()
def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None):
self.client = client if client is not None else httpx.AsyncClient()
self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL
self.auth_header: str = (
"Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode()
Expand Down
31 changes: 29 additions & 2 deletions openapi_python_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import threading
from typing import Any, Dict

import httpx
Expand All @@ -14,14 +15,33 @@ class Client:
Synchronous client for making authenticated requests to Openapi endpoints.
"""

def __init__(self, token: str):
self.client = httpx.Client()
def __init__(self, token: str, client: Any = None):
self._client = client
self._thread_local = threading.local()
self.auth_header: str = f"Bearer {token}"
self.headers: Dict[str, str] = {
"Authorization": self.auth_header,
"Content-Type": "application/json",
}

@property
def client(self) -> Any:
"""
Thread-safe access to the underlying HTTP client.
If a custom client was provided at initialization, it is returned.
Otherwise, a thread-local httpx.Client is created and returned.
"""
if self._client is not None:
return self._client

if not hasattr(self._thread_local, "client"):
self._thread_local.client = httpx.Client()
return self._thread_local.client

@client.setter
def client(self, value: Any):
self._client = value

def __enter__(self):
"""Enable use as a synchronous context manager."""
return self
Expand All @@ -47,6 +67,13 @@ def request(
payload = payload or {}
params = params or {}
url = url or ""

if params:
import urllib.parse
query_string = urllib.parse.urlencode(params, doseq=True)
url = f"{url}&{query_string}" if "?" in url else f"{url}?{query_string}"
params = None

data = self.client.request(
method=method,
url=url,
Expand Down
23 changes: 21 additions & 2 deletions openapi_python_sdk/oauth_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import threading
from typing import Any, Dict, List

import httpx
Expand All @@ -12,8 +13,9 @@ class OauthClient:
Synchronous client for handling Openapi authentication and token management.
"""

def __init__(self, username: str, apikey: str, test: bool = False):
self.client = httpx.Client()
def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None):
self._client = client
self._thread_local = threading.local()
self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL
self.auth_header: str = (
"Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode()
Expand All @@ -23,6 +25,23 @@ def __init__(self, username: str, apikey: str, test: bool = False):
"Content-Type": "application/json",
}

@property
def client(self) -> Any:
"""
Thread-safe access to the underlying HTTP client.
If a custom client was provided at initialization, it is returned.
Otherwise, a thread-local httpx.Client is created and returned.
"""
if self._client is not None:
return self._client
if not hasattr(self._thread_local, "client"):
self._thread_local.client = httpx.Client()
return self._thread_local.client

@client.setter
def client(self, value: Any):
self._client = value

def __enter__(self):
"""Enable use as a synchronous context manager."""
return self
Expand Down
10 changes: 10 additions & 0 deletions tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ async def test_get_scopes(self, mock_httpx):
await oauth.aclose()
mock_httpx.return_value.aclose.assert_called_once()

def test_custom_client_transport(self):
custom_client = MagicMock()
oauth = AsyncOauthClient(username="user", apikey="key", client=custom_client)
self.assertEqual(oauth.client, custom_client)


class TestAsyncClient(unittest.IsolatedAsyncioTestCase):
"""
Expand Down Expand Up @@ -85,6 +90,11 @@ async def test_request_post(self, mock_httpx):
await client.aclose()
mock_httpx.return_value.aclose.assert_called_once()

def test_custom_client_transport(self):
custom_client = MagicMock()
client = AsyncClient(token="abc123", client=custom_client)
self.assertEqual(client.client, custom_client)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def test_auth_header_is_basic(self, mock_httpx):
oauth = OauthClient(username="user", apikey="key")
self.assertTrue(oauth.auth_header.startswith("Basic "))

def test_custom_client_transport(self):
custom_client = MagicMock()
oauth = OauthClient(username="user", apikey="key", client=custom_client)
self.assertEqual(oauth.client, custom_client)


class TestClient(unittest.TestCase):

Expand Down Expand Up @@ -109,6 +114,11 @@ def test_defaults_on_empty_request(self, mock_httpx):
method="GET", url="", headers=client.headers, json={}, params={}
)

def test_custom_client_transport(self):
custom_client = MagicMock()
client = Client(token="tok", client=custom_client)
self.assertEqual(client.client, custom_client)


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions tests/test_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import threading
import unittest

import httpx

from openapi_python_sdk import Client, OauthClient


class TestThreadSafety(unittest.TestCase):
def test_oauth_client_thread_safety(self):
oauth = OauthClient(username="user", apikey="key")

clients = []
def get_client():
clients.append(oauth.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# Each thread should have gotten a unique client instance
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 5)

def test_client_thread_safety(self):
client = Client(token="tok")

clients = []
def get_client():
clients.append(client.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# Each thread should have gotten a unique client instance
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 5)

def test_shared_client_injection_still_works(self):
# If we explicitly pass a client, it SHOULD be shared (backward compatibility)
shared_engine = httpx.Client()
oauth = OauthClient(username="user", apikey="key", client=shared_engine)

clients = []
def get_client():
clients.append(oauth.client)

threads = [threading.Thread(target=get_client) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

# All threads should have the SAME instance because it was injected
self.assertEqual(len(clients), 5)
self.assertEqual(len(set(id(c) for c in clients)), 1)
self.assertEqual(id(clients[0]), id(shared_engine))

if __name__ == "__main__":
unittest.main()
Loading