diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 07638a20..84f9c4cc 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -9,6 +9,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime + import pytest from sqlalchemy import Column from sqlalchemy import ForeignKey @@ -25,6 +27,7 @@ from sqlalchemy.sql import table from tests.unit.conftest import sqlalchemy_version +from trino.sqlalchemy.datatype import TIMESTAMP from trino.sqlalchemy.dialect import TrinoDialect metadata = MetaData() @@ -206,6 +209,24 @@ def test_try_cast(dialect): assert str(query) == 'SELECT try_cast("table".id as VARCHAR) AS id \nFROM "table"' +def test_timestamp_literal_processor(dialect): + ts_col = column("ts", TIMESTAMP()) + tbl = table("t", ts_col) + dt = datetime.datetime(2026, 6, 17, 9, 57, 43, 244000) + stmt = select(tbl).where(ts_col == dt) + query = stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert "TIMESTAMP '2026-06-17 09:57:43.244'" in str(query) + + +def test_timestamp_literal_processor_no_microseconds(dialect): + ts_col = column("ts", TIMESTAMP()) + tbl = table("t", ts_col) + dt = datetime.datetime(2026, 6, 17, 9, 57, 43) + stmt = select(tbl).where(ts_col == dt) + query = stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert "TIMESTAMP '2026-06-17 09:57:43'" in str(query) + + def test_catalogs_create_table_with_pk(dialect): with pytest.warns(SAWarning, match="Trino does not support PRIMARY KEY constraints. Constraint will be ignored."): statement = CreateTable(table_with_pk) diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index f5ecf433..5f08ff93 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -9,6 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime import re from collections.abc import Iterator from typing import Any @@ -81,6 +82,17 @@ def __init__(self, precision=None, timezone=False): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision + def literal_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + ts = value.strftime("%Y-%m-%d %H:%M:%S") + if value.microsecond: + ts += f".{value.microsecond // 1000:03d}" + return f"TIMESTAMP '{ts}'" + return str(value) + + return process + class JSON(TypeDecorator): impl = JSON