diff --git a/README.md b/README.md index 39db3d5..b48d866 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,26 @@ resp = client.request( ) ``` +### Customizing the Transport Layer + +If you need to configure custom retry logic, proxies, or use a different HTTP client (such as passing a `requests.Session` with a custom urllib3 `Retry`), you can inject it directly using the `client` parameter on any SDK class: + +```python +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +from openapi_python_sdk import Client +import requests + +retry = Retry(total=3) +adapter = HTTPAdapter(max_retries=retry) + +session = requests.Session() +session.mount("https://", adapter) + +# Pass the custom session to the Client explicitly +client = Client("token", client=session) +``` + ## Async Usage The SDK provides `AsyncClient` and `AsyncOauthClient` for use with asynchronous frameworks like FastAPI or `aiohttp`. diff --git a/openapi_python_sdk/async_client.py b/openapi_python_sdk/async_client.py index 7ff6e6a..e0d0a05 100644 --- a/openapi_python_sdk/async_client.py +++ b/openapi_python_sdk/async_client.py @@ -10,8 +10,8 @@ class AsyncClient: Suitable for use with FastAPI, aiohttp, etc. """ - def __init__(self, token: str): - self.client = httpx.AsyncClient() + def __init__(self, token: str, client: Any = None): + self.client = client if client is not None else httpx.AsyncClient() self.auth_header: str = f"Bearer {token}" self.headers: Dict[str, str] = { "Authorization": self.auth_header, @@ -43,6 +43,13 @@ async def request( payload = payload or {} params = params or {} url = url or "" + + if params: + import urllib.parse + query_string = urllib.parse.urlencode(params, doseq=True) + url = f"{url}&{query_string}" if "?" in url else f"{url}?{query_string}" + params = None + resp = await self.client.request( method=method, url=url, diff --git a/openapi_python_sdk/async_oauth_client.py b/openapi_python_sdk/async_oauth_client.py index 2633c93..c3234cf 100644 --- a/openapi_python_sdk/async_oauth_client.py +++ b/openapi_python_sdk/async_oauth_client.py @@ -12,8 +12,8 @@ class AsyncOauthClient: Suitable for use with FastAPI, aiohttp, etc. """ - def __init__(self, username: str, apikey: str, test: bool = False): - self.client = httpx.AsyncClient() + def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None): + self.client = client if client is not None else httpx.AsyncClient() self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL self.auth_header: str = ( "Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode() diff --git a/openapi_python_sdk/client.py b/openapi_python_sdk/client.py index 16956cf..5393fda 100644 --- a/openapi_python_sdk/client.py +++ b/openapi_python_sdk/client.py @@ -1,4 +1,5 @@ import json +import threading from typing import Any, Dict import httpx @@ -14,14 +15,33 @@ class Client: Synchronous client for making authenticated requests to Openapi endpoints. """ - def __init__(self, token: str): - self.client = httpx.Client() + def __init__(self, token: str, client: Any = None): + self._client = client + self._thread_local = threading.local() self.auth_header: str = f"Bearer {token}" self.headers: Dict[str, str] = { "Authorization": self.auth_header, "Content-Type": "application/json", } + @property + def client(self) -> Any: + """ + Thread-safe access to the underlying HTTP client. + If a custom client was provided at initialization, it is returned. + Otherwise, a thread-local httpx.Client is created and returned. + """ + if self._client is not None: + return self._client + + if not hasattr(self._thread_local, "client"): + self._thread_local.client = httpx.Client() + return self._thread_local.client + + @client.setter + def client(self, value: Any): + self._client = value + def __enter__(self): """Enable use as a synchronous context manager.""" return self @@ -47,6 +67,13 @@ def request( payload = payload or {} params = params or {} url = url or "" + + if params: + import urllib.parse + query_string = urllib.parse.urlencode(params, doseq=True) + url = f"{url}&{query_string}" if "?" in url else f"{url}?{query_string}" + params = None + data = self.client.request( method=method, url=url, diff --git a/openapi_python_sdk/oauth_client.py b/openapi_python_sdk/oauth_client.py index 8c8c446..bcb44da 100644 --- a/openapi_python_sdk/oauth_client.py +++ b/openapi_python_sdk/oauth_client.py @@ -1,4 +1,5 @@ import base64 +import threading from typing import Any, Dict, List import httpx @@ -12,8 +13,9 @@ class OauthClient: Synchronous client for handling Openapi authentication and token management. """ - def __init__(self, username: str, apikey: str, test: bool = False): - self.client = httpx.Client() + def __init__(self, username: str, apikey: str, test: bool = False, client: Any = None): + self._client = client + self._thread_local = threading.local() self.url: str = TEST_OAUTH_BASE_URL if test else OAUTH_BASE_URL self.auth_header: str = ( "Basic " + base64.b64encode(f"{username}:{apikey}".encode("utf-8")).decode() @@ -23,6 +25,23 @@ def __init__(self, username: str, apikey: str, test: bool = False): "Content-Type": "application/json", } + @property + def client(self) -> Any: + """ + Thread-safe access to the underlying HTTP client. + If a custom client was provided at initialization, it is returned. + Otherwise, a thread-local httpx.Client is created and returned. + """ + if self._client is not None: + return self._client + if not hasattr(self._thread_local, "client"): + self._thread_local.client = httpx.Client() + return self._thread_local.client + + @client.setter + def client(self, value: Any): + self._client = value + def __enter__(self): """Enable use as a synchronous context manager.""" return self diff --git a/tests/test_async_client.py b/tests/test_async_client.py index d25454b..461305b 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -42,6 +42,11 @@ async def test_get_scopes(self, mock_httpx): await oauth.aclose() mock_httpx.return_value.aclose.assert_called_once() + def test_custom_client_transport(self): + custom_client = MagicMock() + oauth = AsyncOauthClient(username="user", apikey="key", client=custom_client) + self.assertEqual(oauth.client, custom_client) + class TestAsyncClient(unittest.IsolatedAsyncioTestCase): """ @@ -85,6 +90,11 @@ async def test_request_post(self, mock_httpx): await client.aclose() mock_httpx.return_value.aclose.assert_called_once() + def test_custom_client_transport(self): + custom_client = MagicMock() + client = AsyncClient(token="abc123", client=custom_client) + self.assertEqual(client.client, custom_client) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_client.py b/tests/test_client.py index 75fe1e4..3b8f874 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -56,6 +56,11 @@ def test_auth_header_is_basic(self, mock_httpx): oauth = OauthClient(username="user", apikey="key") self.assertTrue(oauth.auth_header.startswith("Basic ")) + def test_custom_client_transport(self): + custom_client = MagicMock() + oauth = OauthClient(username="user", apikey="key", client=custom_client) + self.assertEqual(oauth.client, custom_client) + class TestClient(unittest.TestCase): @@ -109,6 +114,11 @@ def test_defaults_on_empty_request(self, mock_httpx): method="GET", url="", headers=client.headers, json={}, params={} ) + def test_custom_client_transport(self): + custom_client = MagicMock() + client = Client(token="tok", client=custom_client) + self.assertEqual(client.client, custom_client) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_thread_safety.py b/tests/test_thread_safety.py new file mode 100644 index 0000000..5763084 --- /dev/null +++ b/tests/test_thread_safety.py @@ -0,0 +1,65 @@ +import threading +import unittest + +import httpx + +from openapi_python_sdk import Client, OauthClient + + +class TestThreadSafety(unittest.TestCase): + def test_oauth_client_thread_safety(self): + oauth = OauthClient(username="user", apikey="key") + + clients = [] + def get_client(): + clients.append(oauth.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have gotten a unique client instance + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 5) + + def test_client_thread_safety(self): + client = Client(token="tok") + + clients = [] + def get_client(): + clients.append(client.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should have gotten a unique client instance + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 5) + + def test_shared_client_injection_still_works(self): + # If we explicitly pass a client, it SHOULD be shared (backward compatibility) + shared_engine = httpx.Client() + oauth = OauthClient(username="user", apikey="key", client=shared_engine) + + clients = [] + def get_client(): + clients.append(oauth.client) + + threads = [threading.Thread(target=get_client) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should have the SAME instance because it was injected + self.assertEqual(len(clients), 5) + self.assertEqual(len(set(id(c) for c in clients)), 1) + self.assertEqual(id(clients[0]), id(shared_engine)) + +if __name__ == "__main__": + unittest.main()