Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,100 @@ def test_isolation_level(self):
isolation_level = self.dialect.get_isolation_level(dbapi_conn)
assert isolation_level == "SERIALIZABLE"

@staticmethod
def _partitions_connection(partition_names, data_types):
"""Build a mock SQLAlchemy connection whose execute() returns a result.

The result exposes ``cursor.description`` shaped like the given
partition column names/types and supports both the SQLAlchemy 2.0
context-manager path and the legacy ``fetchall``/``close`` path.
"""
result = mock.MagicMock()
result.cursor.description = list(zip(partition_names, data_types))
result.closed = False
# Support the ``with connection.execute(...) as conn:`` (2.0) path.
result.__enter__.return_value = result
result.__exit__.return_value = False

connection = mock.Mock()
connection.execute.return_value = result
return connection, result

def test_get_partitions_hive_table_returns_partition_names(self):
partition_names = ["year", "month"]
data_types = ["integer", "integer"]
connection, _ = self._partitions_connection(partition_names, data_types)

result = self.dialect._get_partitions(connection, "some_table", "some_schema")

assert result == partition_names

def test_get_partitions_iceberg_table_returns_none(self):
# An Iceberg ``$partitions`` table has this exact shape and must be
# recognized as metadata rather than treated as partition columns.
partition_names = ["partition", "record_count", "file_count", "total_size", "data"]
data_types = ["row(...)", "bigint", "bigint", "bigint", "row(...)"]
connection, _ = self._partitions_connection(partition_names, data_types)

result = self.dialect._get_partitions(connection, "some_table", "some_schema")

assert result is None

def test_get_partitions_uses_context_manager_on_sqlalchemy_2(self):
partition_names = ["year"]
data_types = ["integer"]
connection, result = self._partitions_connection(partition_names, data_types)

with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "2.0.0"):
self.dialect._get_partitions(connection, "some_table", "some_schema")

# 2.0 path consumes the result as a context manager and never falls
# back to manual fetch/close.
result.__enter__.assert_called_once()
result.__exit__.assert_called_once()
result.fetchall.assert_not_called()
result.close.assert_not_called()

def test_get_partitions_fetches_and_closes_on_legacy_sqlalchemy(self):
partition_names = ["year"]
data_types = ["integer"]
connection, result = self._partitions_connection(partition_names, data_types)

with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "1.4.0"):
self.dialect._get_partitions(connection, "some_table", "some_schema")

# Legacy path must drain the cursor so the query is FINISHED (not
# CANCELED) and then close it explicitly.
result.fetchall.assert_called_once()
result.close.assert_called_once()
result.__enter__.assert_not_called()

def test_get_partitions_skips_close_when_already_closed_on_legacy(self):
partition_names = ["year"]
data_types = ["integer"]
connection, result = self._partitions_connection(partition_names, data_types)
result.closed = True

with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "1.4.0"):
self.dialect._get_partitions(connection, "some_table", "some_schema")

result.close.assert_not_called()

def test_get_partitions_defaults_schema_when_not_given(self):
partition_names = ["year"]
data_types = ["integer"]
connection, _ = self._partitions_connection(partition_names, data_types)

with mock.patch.object(
self.dialect, "_get_default_schema_name", return_value="default_schema"
) as get_default_schema:
self.dialect._get_partitions(connection, "some_table")

get_default_schema.assert_called_once_with(connection)
# The defaulted schema must appear in the issued query.
query = str(connection.execute.call_args[0][0])
assert "default_schema" in query


def test_trino_connection_basic_auth():
dialect = TrinoDialect()
Expand Down
18 changes: 15 additions & 3 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from sqlalchemy import exc
from sqlalchemy import sql
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
from sqlalchemy.engine import Engine
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.default import DefaultDialect
Expand Down Expand Up @@ -221,9 +222,20 @@ def _get_partitions(
SELECT * FROM {schema}."{table_name}$partitions"
"""
).strip()
res = connection.execute(sql.text(query))
partition_names = [desc[0] for desc in res.cursor.description]
data_types = [desc[1] for desc in res.cursor.description]

if SQLALCHEMY_VERSION >= '2.0.0':
# Lower versions of SQLAlchemy doesn't support execution of query as ctxt man
with connection.execute(sql.text(query)) as conn:
partition_names = [desc[0] for desc in conn.cursor.description]
data_types = [desc[1] for desc in conn.cursor.description]
else:
conn = connection.execute(sql.text(query))
partition_names = [desc[0] for desc in conn.cursor.description]
data_types = [desc[1] for desc in conn.cursor.description]
res = conn.fetchall() # Fetching data to consider query as FINISHED and not CANCELED
if not conn.closed:
conn.close()

# Compare the column names and types to the shape of an Iceberg $partitions table
if (partition_names == ['partition', 'record_count', 'file_count', 'total_size', 'data']
and data_types[0].startswith('row(')
Expand Down