Skip to content

Commit c8cabc4

Browse files
committed
add http client connection pooling for efficient connection reuse
1 parent bae6ae2 commit c8cabc4

5 files changed

Lines changed: 114 additions & 39 deletions

File tree

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .config import configure, set_config_provider
2+
from .pool import client_scope, close_clients
23

3-
__all__ = ["configure", "set_config_provider"]
4+
__all__ = ["configure", "set_config_provider", "client_scope", "close_clients"]

src/bubble_data_api_client/client/orm.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,11 @@
66
from pydantic import Field
77

88
from bubble_data_api_client.client.raw_client import RawClient
9-
from bubble_data_api_client.config import get_config
109
from bubble_data_api_client.constraints import ConstraintTypes, constraint
1110

1211

1312
def _get_client() -> RawClient:
14-
config = get_config()
15-
api_root = config.get("data_api_root_url")
16-
api_key = config.get("api_key")
17-
if not api_root:
18-
raise RuntimeError("data_api_root_url")
19-
if not api_key:
20-
raise RuntimeError("api_key")
21-
return RawClient(data_api_root_url=api_root, api_key=api_key)
13+
return RawClient()
2214

2315

2416
class BubbleBaseModel(PydanticBaseModel):

src/bubble_data_api_client/client/raw_client.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,13 @@ class RawClient:
2424
https://www.postman.com/bubbleapi/bubble/request/jigyk5v/
2525
"""
2626

27-
_data_api_root_url: str
28-
_api_key: str
2927
_transport: Transport
3028

31-
def __init__(
32-
self,
33-
data_api_root_url: str,
34-
api_key: str,
35-
):
36-
self._data_api_root_url = data_api_root_url
37-
self._api_key = api_key
29+
def __init__(self) -> None:
30+
pass
3831

3932
async def __aenter__(self) -> typing.Self:
40-
self._transport = Transport(
41-
base_url=self._data_api_root_url,
42-
api_key=self._api_key,
43-
)
33+
self._transport = Transport()
4434
await self._transport.__aenter__()
4535
return self
4636

src/bubble_data_api_client/pool.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Client pool for efficient connection reuse."""
2+
3+
import asyncio
4+
import atexit
5+
import threading
6+
from collections.abc import AsyncIterator
7+
from contextlib import asynccontextmanager
8+
from typing import Any
9+
10+
import httpx
11+
12+
from .config import get_config
13+
from .transport import httpx_client_factory
14+
15+
# global client pool keyed by config
16+
_clients: dict[tuple[str, str], httpx.AsyncClient] = {}
17+
_lock = threading.Lock()
18+
19+
20+
def _make_client_key(config: dict[str, Any]) -> tuple[str, str]:
21+
"""Generate a unique key for client pooling based on config."""
22+
base_url = config.get("data_api_root_url") or ""
23+
api_key = config.get("api_key") or ""
24+
return (base_url, api_key)
25+
26+
27+
def get_client() -> httpx.AsyncClient:
28+
"""Get or create a client for the current config. Thread-safe."""
29+
config = get_config()
30+
key = _make_client_key(config)
31+
32+
# fast path: no lock if client exists
33+
if key in _clients:
34+
return _clients[key]
35+
36+
# slow path: acquire lock for creation
37+
with _lock:
38+
# double-check after acquiring lock
39+
if key not in _clients:
40+
base_url = config.get("data_api_root_url")
41+
if not base_url:
42+
raise RuntimeError("data_api_root_url")
43+
api_key = config.get("api_key")
44+
if not api_key:
45+
raise RuntimeError("api_key")
46+
_clients[key] = httpx_client_factory(base_url=base_url, api_key=api_key)
47+
return _clients[key]
48+
49+
50+
async def close_clients() -> None:
51+
"""Close all clients in the pool. Thread-safe. Safe to call multiple times."""
52+
with _lock:
53+
clients_to_close = list(_clients.values())
54+
_clients.clear()
55+
56+
for client in clients_to_close:
57+
await client.aclose()
58+
59+
60+
@asynccontextmanager
61+
async def client_scope() -> AsyncIterator[None]:
62+
"""Scope that ensures close_clients() is called on exit."""
63+
try:
64+
yield
65+
finally:
66+
await close_clients()
67+
68+
69+
def _atexit_cleanup() -> None:
70+
"""Best-effort cleanup at interpreter exit."""
71+
with _lock:
72+
clients_to_close = list(_clients.values())
73+
_clients.clear()
74+
75+
if not clients_to_close:
76+
return
77+
78+
# check if there's already a running loop
79+
try:
80+
running_loop = asyncio.get_running_loop()
81+
except RuntimeError:
82+
running_loop = None
83+
84+
try:
85+
if running_loop is not None:
86+
# loop still running at atexit, schedule cleanup tasks
87+
for client in clients_to_close:
88+
running_loop.create_task(client.aclose())
89+
else:
90+
# no running loop, create one for cleanup
91+
loop = asyncio.new_event_loop()
92+
for client in clients_to_close:
93+
loop.run_until_complete(client.aclose())
94+
loop.close()
95+
except Exception:
96+
pass
97+
98+
99+
atexit.register(_atexit_cleanup)

src/bubble_data_api_client/transport.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,21 @@ def httpx_client_factory(
2424
class Transport:
2525
"""
2626
Transport layer focuses on HTTP.
27-
- manage connections
28-
- authentication
29-
- headers
30-
- retries, backoff
31-
- timeouts
32-
- exposes errors to the client
27+
- authentication, headers, retries, timeouts: configured via httpx_client_factory
28+
- connection lifecycle: managed by pool module
29+
- HTTP verb methods: get, post, patch, put, delete
30+
- error handling: raise_for_status on responses
3331
"""
3432

35-
_base_url: str
36-
_api_key: str
3733
_http: httpx.AsyncClient
3834

39-
def __init__(self, base_url: str, api_key: str):
40-
self._base_url = base_url
41-
self._api_key = api_key
35+
def __init__(self) -> None:
36+
pass
4237

4338
async def __aenter__(self) -> typing.Self:
44-
self._http = httpx_client_factory(
45-
base_url=self._base_url,
46-
api_key=self._api_key,
47-
)
39+
from .pool import get_client
4840

41+
self._http = get_client()
4942
return self
5043

5144
async def __aexit__(
@@ -54,7 +47,7 @@ async def __aexit__(
5447
exc_val: BaseException | None,
5548
exc_tb: types.TracebackType | None,
5649
) -> None:
57-
await self._http.aclose()
50+
pass
5851

5952
async def request(
6053
self,

0 commit comments

Comments
 (0)