Skip to content

Commit 7e8af3e

Browse files
authored
fix: race condition in connection pool initialization (#346) (#347)
Fixes #346. Implements double-checked locking in `AsyncDatabaseConfig` and `SyncDatabaseConfig` to ensure thread-safe and coroutine-safe pool initialization. This prevents the issue where concurrent calls to `provide_pool()` (or implicit pool creation via session/connection provision) could create multiple pool instances, leading to 'invalid connection' errors when releasing connections. Added `tests/integration/test_pool_concurrency.py` which reliably reproduces the race condition (failing before this fix) and verifies the fix.
1 parent 00c2f84 commit 7e8af3e

2 files changed

Lines changed: 104 additions & 14 deletions

File tree

sqlspec/config.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import threading
13
from abc import ABC, abstractmethod
24
from collections.abc import Callable
35
from inspect import Signature, signature
@@ -1510,7 +1512,7 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr
15101512
class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
15111513
"""Base class for sync database configurations with connection pooling."""
15121514

1513-
__slots__ = ("connection_config",)
1515+
__slots__ = ("_pool_lock", "connection_config")
15141516
is_async: "ClassVar[bool]" = False
15151517
supports_connection_pooling: "ClassVar[bool]" = True
15161518
migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker
@@ -1549,6 +1551,7 @@ def __init__(
15491551
self.driver_features.setdefault("storage_capabilities", self.storage_capabilities())
15501552
self._promote_driver_feature_hooks()
15511553
self._configure_observability_extensions()
1554+
self._pool_lock = threading.Lock()
15521555

15531556
def create_pool(self) -> PoolT:
15541557
"""Create and return the connection pool.
@@ -1558,9 +1561,14 @@ def create_pool(self) -> PoolT:
15581561
"""
15591562
if self.connection_instance is not None:
15601563
return self.connection_instance
1561-
self.connection_instance = self._create_pool()
1562-
self.get_observability_runtime().emit_pool_create(self.connection_instance)
1563-
return self.connection_instance
1564+
1565+
with self._pool_lock:
1566+
if self.connection_instance is not None:
1567+
return self.connection_instance
1568+
1569+
self.connection_instance = self._create_pool()
1570+
self.get_observability_runtime().emit_pool_create(self.connection_instance)
1571+
return self.connection_instance
15641572

15651573
def close_pool(self) -> None:
15661574
"""Close the connection pool."""
@@ -1572,9 +1580,7 @@ def close_pool(self) -> None:
15721580

15731581
def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
15741582
"""Provide pool instance."""
1575-
if self.connection_instance is None:
1576-
self.connection_instance = self.create_pool()
1577-
return self.connection_instance
1583+
return self.create_pool()
15781584

15791585
def create_connection(self) -> ConnectionT:
15801586
"""Create a database connection."""
@@ -1709,7 +1715,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye
17091715
class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
17101716
"""Base class for async database configurations with connection pooling."""
17111717

1712-
__slots__ = ("connection_config",)
1718+
__slots__ = ("_pool_lock", "connection_config")
17131719
is_async: "ClassVar[bool]" = True
17141720
supports_connection_pooling: "ClassVar[bool]" = True
17151721
migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker
@@ -1750,6 +1756,7 @@ def __init__(
17501756
self.driver_features.setdefault("storage_capabilities", self.storage_capabilities())
17511757
self._promote_driver_feature_hooks()
17521758
self._configure_observability_extensions()
1759+
self._pool_lock = asyncio.Lock()
17531760

17541761
async def create_pool(self) -> PoolT:
17551762
"""Create and return the connection pool.
@@ -1759,9 +1766,14 @@ async def create_pool(self) -> PoolT:
17591766
"""
17601767
if self.connection_instance is not None:
17611768
return self.connection_instance
1762-
self.connection_instance = await self._create_pool()
1763-
self.get_observability_runtime().emit_pool_create(self.connection_instance)
1764-
return self.connection_instance
1769+
1770+
async with self._pool_lock:
1771+
if self.connection_instance is not None:
1772+
return self.connection_instance
1773+
1774+
self.connection_instance = await self._create_pool()
1775+
self.get_observability_runtime().emit_pool_create(self.connection_instance)
1776+
return self.connection_instance
17651777

17661778
async def close_pool(self) -> None:
17671779
"""Close the connection pool."""
@@ -1773,9 +1785,7 @@ async def close_pool(self) -> None:
17731785

17741786
async def provide_pool(self, *args: Any, **kwargs: Any) -> PoolT:
17751787
"""Provide pool instance."""
1776-
if self.connection_instance is None:
1777-
self.connection_instance = await self.create_pool()
1778-
return self.connection_instance
1788+
return await self.create_pool()
17791789

17801790
async def create_connection(self) -> ConnectionT:
17811791
"""Create a database connection."""
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import threading
5+
from typing import TYPE_CHECKING
6+
7+
import pytest
8+
9+
from sqlspec.adapters.asyncpg import AsyncpgConfig
10+
from sqlspec.adapters.duckdb import DuckDBConfig
11+
12+
if TYPE_CHECKING:
13+
from pytest_databases.docker.postgres import PostgresService
14+
15+
from sqlspec.adapters.asyncpg import AsyncpgPool
16+
from sqlspec.adapters.duckdb import DuckDBConnectionPool
17+
18+
19+
@pytest.mark.asyncio
async def test_asyncpg_pool_concurrency(postgres_service: PostgresService) -> None:
    """Verify that concurrent ``provide_pool()`` calls yield a single shared pool.

    Regression test for the check-then-act race in ``AsyncDatabaseConfig``:
    before double-checked locking was added, overlapping calls could each
    create their own pool instance.

    Args:
        postgres_service: pytest-databases fixture describing a running
            PostgreSQL container.
    """
    config_params = {
        "host": postgres_service.host,
        "port": postgres_service.port,
        "user": postgres_service.user,
        "password": postgres_service.password,
        "database": postgres_service.database,
    }
    # Start from a config with no pre-existing pool so every task goes
    # through the lazy-initialization path.
    config = AsyncpgConfig(connection_config=config_params, connection_instance=None)

    # Fire many provide_pool() calls at once; gather() runs them on the same
    # event loop so their awaits interleave at the initialization point.
    # (No artificial delay is needed: the await inside pool creation is the
    # suspension point where tasks overlap.)
    pools = await asyncio.gather(*(config.provide_pool() for _ in range(50)))

    first_pool = pools[0]
    unique_pools = {id(p) for p in pools}

    # Close before asserting so the pool is released even when the test fails.
    await config.close_pool()

    assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique pools created."
    assert all(p is first_pool for p in pools)
49+
50+
51+
def test_duckdb_pool_concurrency() -> None:
    """Verify that concurrent ``provide_pool()`` calls yield a single pool (sync).

    Regression test for the check-then-act race in ``SyncDatabaseConfig``.
    A ``threading.Barrier`` releases all worker threads at once so their
    ``provide_pool()`` calls genuinely overlap instead of being serialized
    by thread start-up latency, making the race reliably reproducible.
    """
    config = DuckDBConfig(connection_config={"database": ":memory:"})

    thread_count = 50
    # Each worker writes into its own slot; per-index assignment avoids any
    # contention on the results container itself.
    results: list[DuckDBConnectionPool | None] = [None] * thread_count
    exceptions: list[Exception] = []
    barrier = threading.Barrier(thread_count)

    def get_pool(index: int) -> None:
        # Worker body: wait until every thread is ready, then race into
        # provide_pool(). Any failure is captured for the main thread.
        try:
            barrier.wait()
            results[index] = config.provide_pool()
        except Exception as e:
            exceptions.append(e)

    threads = [threading.Thread(target=get_pool, args=(i,)) for i in range(thread_count)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if exceptions:
        pytest.fail(f"Exceptions in threads: {exceptions}")

    unique_pools = {id(p) for p in results if p is not None}
    config.close_pool()

    assert len(unique_pools) == 1, f"Race condition detected! {len(unique_pools)} unique DuckDB pools created."

0 commit comments

Comments
 (0)