Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions tests/unit/test_executemany.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from decimal import Decimal

import pytest

import trino.client
from trino.dbapi import _INSERT_VALUES_RE
from trino.dbapi import Connection
from trino.dbapi import Cursor


class FakeTrinoQuery:
    """Test double standing in for ``TrinoQuery``.

    Every constructed instance registers itself in ``instances`` and keeps
    the SQL text it was given. It always reports an ``INSERT`` update type
    and produces no rows when executed.
    """

    # All fakes created so far, in creation order.
    instances = []

    def __init__(self, request, query, legacy_primitive_types=False, fetch_mode="mapped"):
        # ``request``, ``legacy_primitive_types`` and ``fetch_mode`` are
        # accepted but not stored or used.
        self._update_type = "INSERT"
        self._query = query
        FakeTrinoQuery.instances.append(self)

    def execute(self):
        """Return an iterator that yields nothing."""
        return iter(())

    @property
    def query(self):
        """The SQL string handed to the constructor."""
        return self._query

    @property
    def update_type(self):
        """The update type set at construction (``"INSERT"``)."""
        return self._update_type


class FakeTrinoRequest:
    """Empty placeholder object used where a request is expected."""


def _make_cursor(batch_executemany=True):
    """Build a ``Cursor`` without running ``Cursor.__init__``.

    The cursor is attached to a bare ``Connection`` (likewise created via
    ``__new__``) whose client session stub only exposes ``timezone = None``.

    :param batch_executemany: value stored in the cursor's
        ``_experimental_batch_executemany`` attribute.
    :returns: the hand-assembled ``Cursor``.
    """
    bare_cursor = Cursor.__new__(Cursor)

    bare_connection = Connection.__new__(Connection)
    bare_connection._client_session = type("cs", (), {"timezone": None})()

    bare_cursor._connection = bare_connection
    bare_cursor._request = FakeTrinoRequest()
    bare_cursor._legacy_primitive_types = False
    bare_cursor._experimental_batch_executemany = batch_executemany
    # No query has been run yet.
    bare_cursor._query = None
    bare_cursor._iterator = None
    return bare_cursor


@pytest.fixture
def cursor(monkeypatch):
    """Provide a real ``Cursor`` wired to ``FakeTrinoQuery`` with batching on.

    ``trino.client.TrinoQuery`` is swapped for ``FakeTrinoQuery`` so the
    production ``_executemany_batch_insert`` runs its real code path while
    creating fakes instead of issuing HTTP calls.

    The swap goes through pytest's ``monkeypatch`` fixture, which restores
    the original attribute at teardown even when building the cursor fails;
    a manual patch/``yield``/restore sequence would leave the fake installed
    in that case.
    """
    # Start every test with an empty record of created queries.
    FakeTrinoQuery.instances = []
    monkeypatch.setattr(trino.client, "TrinoQuery", FakeTrinoQuery)
    return _make_cursor(batch_executemany=True)


class TestInsertValuesPattern:
    """Checks which statements ``_INSERT_VALUES_RE`` accepts and rejects."""

    @staticmethod
    def _matches(statement):
        """Return True when ``_INSERT_VALUES_RE`` matches ``statement``."""
        return _INSERT_VALUES_RE.match(statement) is not None

    def test_simple_insert(self):
        assert self._matches("INSERT INTO t (a, b) VALUES (?, ?)")

    def test_insert_with_schema(self):
        statement = 'INSERT INTO "my_schema"."my_table" (col1, col2) VALUES (?, ?)'
        assert self._matches(statement)

    def test_insert_with_catalog_schema(self):
        statement = 'INSERT INTO "catalog"."schema"."table" (a, b, c) VALUES (?, ?, ?)'
        assert self._matches(statement)

    def test_multiline_insert(self):
        raw = ' INSERT INTO "schema"."table" (col1, col2)\n VALUES (?, ?)\n '
        # Surrounding whitespace is removed before matching.
        assert self._matches(raw.strip())

    def test_insert_no_columns(self):
        assert self._matches("INSERT INTO t VALUES (?, ?)")

    def test_select_not_matched(self):
        assert not self._matches("SELECT * FROM t WHERE a = ?")

    def test_update_not_matched(self):
        assert not self._matches("UPDATE t SET a = ? WHERE b = ?")

    def test_insert_select_not_matched(self):
        assert not self._matches("INSERT INTO t SELECT * FROM s")

    def test_trailing_semicolon_not_matched(self):
        assert not self._matches("INSERT INTO t (a) VALUES (?);")

    def test_case_insensitive(self):
        assert self._matches("insert into t values (?)")

    def test_prefix_extraction(self):
        statement = 'INSERT INTO "s"."t" (a, b) VALUES (?, ?)'
        found = _INSERT_VALUES_RE.match(statement)
        assert found is not None
        # Group 1 is the statement text up to and including VALUES.
        prefix = found.group(1)
        assert prefix.strip().endswith("VALUES")


class TestExecutemanyBatchInsert:
    """Exercises ``executemany`` on a cursor that has batching enabled."""

    def test_single_row(self, cursor):
        rows = [(1, "hello")]
        cursor.executemany("INSERT INTO t (a, b) VALUES (?, ?)", rows)
        issued = FakeTrinoQuery.instances
        assert len(issued) == 1
        assert "VALUES (1, 'hello')" in issued[0].query

    def test_multiple_rows_single_batch(self, cursor):
        rows = [(1, "a"), (2, "b"), (3, "c")]
        cursor.executemany("INSERT INTO t (a, b) VALUES (?, ?)", rows)
        assert len(FakeTrinoQuery.instances) == 1
        statement = FakeTrinoQuery.instances[0].query
        for fragment in ("(1, 'a')", "(2, 'b')", "(3, 'c')"):
            assert fragment in statement

    def test_chunking(self, cursor):
        import trino.dbapi as _dbapi

        # Shrink the batch size so five rows need three statements.
        saved_size = _dbapi._EXECUTEMANY_BATCH_SIZE
        _dbapi._EXECUTEMANY_BATCH_SIZE = 2
        try:
            cursor.executemany(
                "INSERT INTO t (a) VALUES (?)",
                [(1,), (2,), (3,), (4,), (5,)],
            )
            statements = [fake.query for fake in FakeTrinoQuery.instances]
            assert len(statements) == 3
            expected = [("(1)", "(2)"), ("(3)", "(4)"), ("(5)",)]
            for statement, fragments in zip(statements, expected):
                for fragment in fragments:
                    assert fragment in statement
        finally:
            _dbapi._EXECUTEMANY_BATCH_SIZE = saved_size

    def test_null_values(self, cursor):
        rows = [(1, None), (None, "test")]
        cursor.executemany("INSERT INTO t (a, b) VALUES (?, ?)", rows)
        statement = FakeTrinoQuery.instances[0].query
        for fragment in ("(1, NULL)", "(NULL, 'test')"):
            assert fragment in statement

    def test_mixed_types(self, cursor):
        rows = [(42, "text", True, Decimal("3.14"))]
        cursor.executemany("INSERT INTO t (a, b, c, d) VALUES (?, ?, ?, ?)", rows)
        statement = FakeTrinoQuery.instances[0].query
        for fragment in ("42", "'text'", "true", "DECIMAL '3.14'"):
            assert fragment in statement

    def test_string_escaping(self, cursor):
        rows = [("it's a test",)]
        cursor.executemany("INSERT INTO t (a) VALUES (?)", rows)
        # The single quote inside the value must come out doubled.
        assert "it''s a test" in FakeTrinoQuery.instances[0].query

    def test_empty_params_does_not_batch(self, cursor):
        statement = "INSERT INTO t (a) VALUES (?)"
        cursor.executemany(statement, [])
        # With no parameter rows executemany() delegates to execute(), so
        # the statement reaches the fake query unmodified rather than being
        # rewritten by the batch path.
        assert cursor._query is not None
        assert cursor._query.query == statement


class TestBatchExecutemanyOptIn:
    """Checks that batching only happens when the cursor flag is set."""

    def test_disabled_by_default(self):
        """With the flag off, an INSERT ``executemany`` goes through
        ``execute()`` once per parameter row."""
        FakeTrinoQuery.instances = []
        real_query_cls = trino.client.TrinoQuery
        trino.client.TrinoQuery = FakeTrinoQuery
        try:
            row_cursor = _make_cursor(batch_executemany=False)
            # Stub the connection hooks set up for the execute() path.
            row_cursor._connection._use_legacy_prepared_statements = lambda: False
            row_cursor._connection._create_request = lambda: FakeTrinoRequest()
            rows = [(1, "a"), (2, "b"), (3, "c")]
            row_cursor.executemany("INSERT INTO t (a, b) VALUES (?, ?)", rows)
            # One fake query per parameter row.
            assert len(FakeTrinoQuery.instances) == 3
        finally:
            trino.client.TrinoQuery = real_query_cls

    def test_enabled_batches(self):
        """With the flag on, an INSERT ``executemany`` issues one combined
        statement."""
        FakeTrinoQuery.instances = []
        real_query_cls = trino.client.TrinoQuery
        trino.client.TrinoQuery = FakeTrinoQuery
        try:
            batch_cursor = _make_cursor(batch_executemany=True)
            rows = [(1, "a"), (2, "b"), (3, "c")]
            batch_cursor.executemany("INSERT INTO t (a, b) VALUES (?, ?)", rows)
            # A single fake query carrying all three rows.
            assert len(FakeTrinoQuery.instances) == 1
            statement = FakeTrinoQuery.instances[0].query
            for fragment in ("(1, 'a')", "(2, 'b')", "(3, 'c')"):
                assert fragment in statement
        finally:
            trino.client.TrinoQuery = real_query_cls

    def test_connection_threads_flag_to_cursor(self):
        """Copy the connection's flag onto a hand-built cursor and check it."""
        # NOTE(review): this never calls Connection.cursor(); the flag is
        # copied by the test itself, so the assertion looks tautological.
        # Consider exercising Connection.cursor() here - TODO confirm.
        bare_connection = Connection.__new__(Connection)
        bare_connection.experimental_batch_executemany = True
        bare_connection.legacy_primitive_types = False
        bare_connection.legacy_prepared_statements = None
        bare_connection._isolation_level = 0  # AUTOCOMMIT
        bare_connection._transaction = None

        bare_cursor = Cursor.__new__(Cursor)
        bare_cursor._connection = bare_connection
        bare_cursor._request = FakeTrinoRequest()
        bare_cursor._iterator = None
        bare_cursor._query = None
        bare_cursor._legacy_primitive_types = False
        bare_cursor._experimental_batch_executemany = (
            bare_connection.experimental_batch_executemany
        )
        assert bare_cursor._experimental_batch_executemany is True
54 changes: 46 additions & 8 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
import datetime
import math
import re
import uuid
from collections import OrderedDict
from decimal import Decimal
Expand Down Expand Up @@ -83,6 +84,13 @@

logger = trino.logging.get_logger(__name__)

_INSERT_VALUES_RE = re.compile(
r"\A(\s*INSERT\s+INTO\s+.+\bVALUES)\s*\([^)]+\)\s*\Z",
re.IGNORECASE | re.DOTALL,
)

_EXECUTEMANY_BATCH_SIZE = 100


class TimeBoundLRUCache:
"""A bounded LRU cache which expires entries after a configured number of seconds.
Expand Down Expand Up @@ -161,6 +169,7 @@ def __init__(
client_tags=None,
legacy_primitive_types=False,
legacy_prepared_statements=None,
experimental_batch_executemany=False,
roles=None,
timezone=None,
encoding: Union[str, List[str]] = _USE_DEFAULT_ENCODING,
Expand Down Expand Up @@ -237,6 +246,7 @@ def __init__(
self._transaction = None
self.legacy_primitive_types = legacy_primitive_types
self.legacy_prepared_statements = legacy_prepared_statements
self.experimental_batch_executemany = experimental_batch_executemany

@property
def isolation_level(self):
Expand Down Expand Up @@ -313,7 +323,8 @@ def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None)
legacy_primitive_types
if legacy_primitive_types is not None
else self.legacy_primitive_types
)
),
experimental_batch_executemany=self.experimental_batch_executemany,
)

def _use_legacy_prepared_statements(self):
Expand Down Expand Up @@ -388,7 +399,8 @@ def __init__(
self,
connection,
request,
legacy_primitive_types: bool = False):
legacy_primitive_types: bool = False,
experimental_batch_executemany: bool = False):
if not isinstance(connection, Connection):
raise ValueError(
"connection must be a Connection object: {}".format(type(connection))
Expand All @@ -400,6 +412,7 @@ def __init__(
self._iterator = None
self._query = None
self._legacy_primitive_types = legacy_primitive_types
self._experimental_batch_executemany = experimental_batch_executemany

def __iter__(self):
return self._iterator
Expand Down Expand Up @@ -658,15 +671,38 @@ def executemany(self, operation, seq_of_params):

Return values are not defined.
"""
if not seq_of_params:
self.execute(operation)
return self

match = _INSERT_VALUES_RE.match(operation.strip())
if match and self._experimental_batch_executemany:
return self._executemany_batch_insert(match.group(1), seq_of_params)

for parameters in seq_of_params[:-1]:
self.execute(operation, parameters)
self.fetchall()
if self._query.update_type is None:
raise NotSupportedError("Query must return update type")
if seq_of_params:
self.execute(operation, seq_of_params[-1])
else:
self.execute(operation)
self.execute(operation, seq_of_params[-1])
return self

def _executemany_batch_insert(self, prefix, seq_of_params):
    """Run a parameterized INSERT as multi-row ``VALUES`` statements.

    The parameter rows are split into chunks of at most
    ``_EXECUTEMANY_BATCH_SIZE``. Each chunk is rendered as
    ``<prefix> (row), (row), ...`` with every value formatted by
    ``_format_prepared_param`` and executed as one query.

    :param prefix: statement text up to and including ``VALUES``.
    :param seq_of_params: sliceable sequence of parameter rows.
    :returns: this cursor.
    :raises NotSupportedError: if an executed query reports no update type.
    """
    total_rows = len(seq_of_params)
    for start in range(0, total_rows, _EXECUTEMANY_BATCH_SIZE):
        stop = start + _EXECUTEMANY_BATCH_SIZE
        rendered_rows = [
            "(%s)" % ", ".join(self._format_prepared_param(value) for value in row)
            for row in seq_of_params[start:stop]
        ]
        statement = "%s %s" % (prefix, ", ".join(rendered_rows))
        self._query = trino.client.TrinoQuery(
            self._request,
            query=statement,
            legacy_primitive_types=self._legacy_primitive_types,
        )
        self._iterator = iter(self._query.execute())
        if self._query.update_type is None:
            raise NotSupportedError("Query must return update type")
        # fetchall() is called after every chunk except the final one.
        if stop < total_rows:
            self.fetchall()
    return self

def fetchone(self) -> Optional[List[Any]]:
Expand Down Expand Up @@ -757,8 +793,10 @@ def __init__(
self,
connection,
request,
legacy_primitive_types: bool = False):
super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
legacy_primitive_types: bool = False,
experimental_batch_executemany: bool = False):
super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types,
experimental_batch_executemany=experimental_batch_executemany)
if self.connection._client_session.encoding is None:
raise ValueError("SegmentCursor can only be used if encoding is set on the connection")

Expand Down
4 changes: 4 additions & 0 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
if "legacy_prepared_statements" in url.query:
kwargs["legacy_prepared_statements"] = json.loads(unquote_plus(url.query["legacy_prepared_statements"]))

if "experimental_batch_executemany" in url.query:
kwargs["experimental_batch_executemany"] = json.loads(
unquote_plus(url.query["experimental_batch_executemany"]))

if "verify" in url.query:
kwargs["verify"] = json.loads(unquote_plus(url.query["verify"]))

Expand Down
Loading