Skip to content

Commit a6a3564

Browse files
committed
Implemented the functionality of poll_pre_ping which executes a query to see if the connection is alive. If not it'll recycle the session
1 parent 6b80531 commit a6a3564

3 files changed

Lines changed: 127 additions & 0 deletions

File tree

src/databricks/sqlalchemy/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,31 @@ def do_rollback(self, dbapi_connection):
336336
# Databricks SQL Does not support transactions
337337
pass
338338

339+
def do_ping(self, dbapi_connection):
340+
"""Test if a database connection is alive.
341+
342+
This method is called by SQLAlchemy when pool_pre_ping=True to verify
343+
connections are still valid before using them from the pool.
344+
345+
Args:
346+
dbapi_connection: A raw DBAPI connection (from databricks-sql-connector)
347+
348+
Returns:
349+
True if the connection is alive, False otherwise.
350+
"""
351+
try:
352+
cursor = dbapi_connection.cursor()
353+
try:
354+
cursor.execute("SELECT VERSION()")
355+
cursor.fetchone()
356+
return True
357+
finally:
358+
cursor.close()
359+
except Exception:
360+
# Any exception means the connection is dead
361+
# SQLAlchemy will discard it and create a new one
362+
return False
363+
339364
@reflection.cache
340365
def has_table(
341366
self, connection, table_name, schema=None, catalog=None, **kwargs

tests/test_local/e2e/test_basic.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,58 @@ def test_table_comment_reflection(self, inspector: Inspector, table: Table):
541541
def test_column_comment(self, inspector: Inspector, table: Table):
542542
result = inspector.get_columns(table.name)[0].get("comment")
543543
assert result == "column comment"
544+
545+
546+
def test_pool_pre_ping_with_closed_connection(connection_details):
547+
"""Test that pool_pre_ping detects closed connections and creates new ones.
548+
549+
This test verifies that when a connection is manually closed (simulating
550+
session expiration), pool_pre_ping detects it and automatically creates
551+
a new connection without raising an error to the user.
552+
"""
553+
conn_string, connect_args = version_agnostic_connect_arguments(connection_details)
554+
555+
# Create engine with pool_pre_ping enabled
556+
engine = create_engine(
557+
conn_string,
558+
connect_args=connect_args,
559+
pool_pre_ping=True,
560+
pool_size=1,
561+
max_overflow=0
562+
)
563+
564+
# Step 1: Create connection and get session ID
565+
with engine.connect() as conn:
566+
result = conn.execute(text("SELECT VERSION()")).scalar()
567+
assert result is not None
568+
569+
# Get session ID of first connection
570+
raw_conn = conn.connection.dbapi_connection
571+
session_id_1 = raw_conn.get_session_id_hex()
572+
assert session_id_1 is not None
573+
574+
# Step 2: Manually close the connection to simulate expiration
575+
pooled_conn = engine.pool._pool.queue[0]
576+
pooled_conn.driver_connection.close()
577+
578+
# Verify connection is closed
579+
assert not pooled_conn.driver_connection.open
580+
581+
# Step 3: Try to use the closed connection - pool_pre_ping should detect and recycle
582+
with engine.connect() as conn:
583+
result = conn.execute(text("SELECT VERSION()")).scalar()
584+
assert result is not None
585+
586+
# Get session ID of new connection
587+
raw_conn = conn.connection.dbapi_connection
588+
session_id_2 = raw_conn.get_session_id_hex()
589+
assert session_id_2 is not None
590+
591+
# Verify a NEW connection was created (different session ID)
592+
assert session_id_1 != session_id_2, (
593+
"pool_pre_ping should have detected the closed connection "
594+
"and created a new one with a different session ID"
595+
)
596+
597+
# Cleanup
598+
engine.dispose()

tests/test_local/test_ping.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Unit tests for do_ping() method in DatabricksDialect."""
2+
import pytest
3+
from unittest.mock import Mock
4+
from databricks.sqlalchemy import DatabricksDialect
5+
6+
7+
class TestDoPing:
8+
"""Test the do_ping() method for connection health checks."""
9+
10+
@pytest.fixture
11+
def dialect(self):
12+
"""Create a DatabricksDialect instance."""
13+
return DatabricksDialect()
14+
15+
def test_do_ping_success(self, dialect):
16+
"""Test do_ping returns True when connection is alive."""
17+
mock_connection = Mock()
18+
mock_cursor = Mock()
19+
mock_connection.cursor.return_value = mock_cursor
20+
21+
result = dialect.do_ping(mock_connection)
22+
23+
assert result is True
24+
mock_cursor.execute.assert_called_once_with("SELECT VERSION()")
25+
mock_cursor.fetchone.assert_called_once()
26+
mock_cursor.close.assert_called_once()
27+
28+
def test_do_ping_failure_cursor_creation(self, dialect):
29+
"""Test do_ping returns False when cursor creation fails."""
30+
mock_connection = Mock()
31+
mock_connection.cursor.side_effect = Exception("Connection closed")
32+
33+
result = dialect.do_ping(mock_connection)
34+
35+
assert result is False
36+
37+
def test_do_ping_failure_execute_and_cursor_closes(self, dialect):
38+
"""Test do_ping returns False on execute error and cursor is closed."""
39+
mock_connection = Mock()
40+
mock_cursor = Mock()
41+
mock_connection.cursor.return_value = mock_cursor
42+
mock_cursor.execute.side_effect = Exception("Query failed")
43+
44+
result = dialect.do_ping(mock_connection)
45+
46+
assert result is False
47+
mock_cursor.close.assert_called_once()

0 commit comments

Comments
 (0)