diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py
index d247a536..1d230865 100644
--- a/tests/unit/sqlalchemy/test_dialect.py
+++ b/tests/unit/sqlalchemy/test_dialect.py
@@ -258,6 +258,113 @@ def test_isolation_level(self):
         isolation_level = self.dialect.get_isolation_level(dbapi_conn)
         assert isolation_level == "SERIALIZABLE"
 
+    @staticmethod
+    def _partitions_connection(partition_names, data_types):
+        """Build a mock SQLAlchemy connection whose execute() returns a result.
+
+        The result exposes ``cursor.description`` shaped like the given
+        partition column names/types and supports both the SQLAlchemy 2.0
+        context-manager path and the legacy ``fetchall``/``close`` path.
+        """
+        result = mock.MagicMock()
+        result.cursor.description = list(zip(partition_names, data_types))
+        result.closed = False
+        # Support the ``with connection.execute(...) as conn:`` (2.0) path.
+        result.__enter__.return_value = result
+        result.__exit__.return_value = False
+
+        connection = mock.Mock()
+        connection.execute.return_value = result
+        return connection, result
+
+    def test_get_partitions_hive_table_returns_partition_names(self):
+        partition_names = ["year", "month"]
+        data_types = ["integer", "integer"]
+        connection, _ = self._partitions_connection(partition_names, data_types)
+
+        result = self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        assert result == partition_names
+
+    def test_get_partitions_iceberg_table_returns_none(self):
+        # An Iceberg ``$partitions`` table has this exact shape and must be
+        # recognized as metadata rather than treated as partition columns.
+        partition_names = ["partition", "record_count", "file_count", "total_size", "data"]
+        data_types = ["row(...)", "bigint", "bigint", "bigint", "row(...)"]
+        connection, _ = self._partitions_connection(partition_names, data_types)
+
+        result = self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        assert result is None
+
+    def test_get_partitions_uses_context_manager_on_sqlalchemy_2(self):
+        partition_names = ["year"]
+        data_types = ["integer"]
+        connection, result = self._partitions_connection(partition_names, data_types)
+
+        with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "2.0.0"):
+            self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        # 2.0 path consumes the result as a context manager and never falls
+        # back to manual fetch/close.
+        result.__enter__.assert_called_once()
+        result.__exit__.assert_called_once()
+        result.fetchall.assert_not_called()
+        result.close.assert_not_called()
+
+    def test_get_partitions_uses_context_manager_on_two_digit_major(self):
+        partition_names = ["year"]
+        data_types = ["integer"]
+        connection, result = self._partitions_connection(partition_names, data_types)
+
+        # "10.0.0" sorts before "2.0.0" as a string, so the major version
+        # must be compared numerically to stay on the context-manager path.
+        with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "10.0.0"):
+            self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        result.__enter__.assert_called_once()
+        result.fetchall.assert_not_called()
+
+    def test_get_partitions_fetches_and_closes_on_legacy_sqlalchemy(self):
+        partition_names = ["year"]
+        data_types = ["integer"]
+        connection, result = self._partitions_connection(partition_names, data_types)
+
+        with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "1.4.0"):
+            self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        # Legacy path must drain the cursor so the query is FINISHED (not
+        # CANCELED) and then close it explicitly.
+        result.fetchall.assert_called_once()
+        result.close.assert_called_once()
+        result.__enter__.assert_not_called()
+
+    def test_get_partitions_skips_close_when_already_closed_on_legacy(self):
+        partition_names = ["year"]
+        data_types = ["integer"]
+        connection, result = self._partitions_connection(partition_names, data_types)
+        result.closed = True
+
+        with mock.patch("trino.sqlalchemy.dialect.SQLALCHEMY_VERSION", "1.4.0"):
+            self.dialect._get_partitions(connection, "some_table", "some_schema")
+
+        result.close.assert_not_called()
+
+    def test_get_partitions_defaults_schema_when_not_given(self):
+        partition_names = ["year"]
+        data_types = ["integer"]
+        connection, _ = self._partitions_connection(partition_names, data_types)
+
+        with mock.patch.object(
+            self.dialect, "_get_default_schema_name", return_value="default_schema"
+        ) as get_default_schema:
+            self.dialect._get_partitions(connection, "some_table")
+
+        get_default_schema.assert_called_once_with(connection)
+        # The defaulted schema must appear in the issued query.
+        query = str(connection.execute.call_args[0][0])
+        assert "default_schema" in query
+
 
 def test_trino_connection_basic_auth():
     dialect = TrinoDialect()
diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py
index ad05aeec..52dbc3bb 100644
--- a/trino/sqlalchemy/dialect.py
+++ b/trino/sqlalchemy/dialect.py
@@ -23,6 +23,7 @@
 
 from sqlalchemy import exc
 from sqlalchemy import sql
+from sqlalchemy import __version__ as SQLALCHEMY_VERSION
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.default import DefaultDialect
@@ -221,9 +222,20 @@ def _get_partitions(
             SELECT * FROM {schema}."{table_name}$partitions"
         """
         ).strip()
-        res = connection.execute(sql.text(query))
-        partition_names = [desc[0] for desc in res.cursor.description]
-        data_types = [desc[1] for desc in res.cursor.description]
+
+        if int(SQLALCHEMY_VERSION.split('.')[0]) >= 2:  # numeric compare: as strings '10.0.0' sorts before '2.0.0'
+            # SQLAlchemy versions below 2.0 do not support using the execute() result as a context manager
+            with connection.execute(sql.text(query)) as conn:
+                partition_names = [desc[0] for desc in conn.cursor.description]
+                data_types = [desc[1] for desc in conn.cursor.description]
+        else:
+            conn = connection.execute(sql.text(query))
+            partition_names = [desc[0] for desc in conn.cursor.description]
+            data_types = [desc[1] for desc in conn.cursor.description]
+            res = conn.fetchall()  # Fetching data to consider query as FINISHED and not CANCELED
+            if not conn.closed:
+                conn.close()
+
         # Compare the column names and types to the shape of an Iceberg $partitions table
         if (partition_names == ['partition', 'record_count', 'file_count', 'total_size', 'data']
                 and data_types[0].startswith('row(')