From 6f0c86aa9734dd8fddb44e5ffbce09df07988ea8 Mon Sep 17 00:00:00 2001 From: Sadha Chilukoori Date: Wed, 3 Jun 2026 08:17:37 -0700 Subject: [PATCH 1/2] Batch INSERT statements in executemany to reduce round trips Add opt-in experimental_batch_executemany connection parameter that batches INSERT...VALUES in executemany() into multi-row statements (100 rows per request) instead of one HTTP round trip per row. Default is False to preserve existing behavior. Batching changes partial failure semantics and rowcount reporting, so it is opt-in for now. --- tests/unit/test_executemany.py | 255 +++++++++++++++++++++++++++++++++ trino/dbapi.py | 54 +++++-- trino/sqlalchemy/dialect.py | 3 + 3 files changed, 304 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_executemany.py diff --git a/tests/unit/test_executemany.py b/tests/unit/test_executemany.py new file mode 100644 index 00000000..ea6fd5b4 --- /dev/null +++ b/tests/unit/test_executemany.py @@ -0,0 +1,255 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from decimal import Decimal + +import pytest + +import trino.client +from trino.dbapi import _INSERT_VALUES_RE +from trino.dbapi import Connection +from trino.dbapi import Cursor + + +class FakeTrinoQuery: + """Hand-written fake replacing TrinoQuery for unit tests. + + Records every SQL string passed and behaves like a successful INSERT. + """ + + instances = [] + + def __init__(self, request, query, legacy_primitive_types=False, fetch_mode="mapped"): + self._query = query + self._update_type = "INSERT" + FakeTrinoQuery.instances.append(self) + + def execute(self): + return iter([]) + + @property + def query(self): + return self._query + + @property + def update_type(self): + return self._update_type + + +class FakeTrinoRequest: + pass + + +def _make_cursor(batch_executemany=True): + cur = Cursor.__new__(Cursor) + cur._connection = Connection.__new__(Connection) + cur._connection._client_session = type("cs", (), {"timezone": None})() + cur._request = FakeTrinoRequest() + cur._iterator = None + cur._query = None + cur._legacy_primitive_types = False + cur._experimental_batch_executemany = batch_executemany + return cur + + +@pytest.fixture +def cursor(): + """Create a real Cursor wired to FakeTrinoQuery with batching enabled. + + Monkeypatches trino.client.TrinoQuery so the production + _executemany_batch_insert runs its real code path but creates + fakes instead of making HTTP calls. + """ + FakeTrinoQuery.instances = [] + original = trino.client.TrinoQuery + trino.client.TrinoQuery = FakeTrinoQuery + + yield _make_cursor(batch_executemany=True) + + trino.client.TrinoQuery = original + + +class TestInsertValuesPattern: + def test_simple_insert(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t (a, b) VALUES (?, ?)") is not None + + def test_insert_with_schema(self): + sql = 'INSERT INTO "my_schema"."my_table" (col1, col2) VALUES (?, ?)' + assert _INSERT_VALUES_RE.match(sql) is not None + + def test_insert_with_catalog_schema(self): + sql = 'INSERT INTO "catalog"."schema"."table" (a, b, c) VALUES (?, ?, ?)' + assert _INSERT_VALUES_RE.match(sql) is not None + + def test_multiline_insert(self): + sql = ' INSERT INTO "schema"."table" (col1, col2)\n VALUES (?, ?)\n ' + assert _INSERT_VALUES_RE.match(sql.strip()) is not None + + def test_insert_no_columns(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t VALUES (?, ?)") is not None + + def test_select_not_matched(self): + assert _INSERT_VALUES_RE.match("SELECT * FROM t WHERE a = ?") is None + + def test_update_not_matched(self): + assert _INSERT_VALUES_RE.match("UPDATE t SET a = ? WHERE b = ?") is None + + def test_insert_select_not_matched(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t SELECT * FROM s") is None + + def test_trailing_semicolon_not_matched(self): + assert _INSERT_VALUES_RE.match("INSERT INTO t (a) VALUES (?);") is None + + def test_case_insensitive(self): + assert _INSERT_VALUES_RE.match("insert into t values (?)") is not None + + def test_prefix_extraction(self): + sql = 'INSERT INTO "s"."t" (a, b) VALUES (?, ?)' + m = _INSERT_VALUES_RE.match(sql) + assert m is not None + assert m.group(1).strip().endswith("VALUES") + + +class TestExecutemanyBatchInsert: + def test_single_row(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "hello")] + ) + assert len(FakeTrinoQuery.instances) == 1 + assert "VALUES (1, 'hello')" in FakeTrinoQuery.instances[0].query + + def test_multiple_rows_single_batch(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "a"), (2, "b"), (3, "c")] + ) + assert len(FakeTrinoQuery.instances) == 1 + sql = FakeTrinoQuery.instances[0].query + assert "(1, 'a')" in sql + assert "(2, 'b')" in sql + assert "(3, 'c')" in sql + + def test_chunking(self, cursor): + import trino.dbapi as _dbapi + original = _dbapi._EXECUTEMANY_BATCH_SIZE + _dbapi._EXECUTEMANY_BATCH_SIZE = 2 + try: + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [(1,), (2,), (3,), (4,), (5,)] + ) + assert len(FakeTrinoQuery.instances) == 3 + assert "(1)" in FakeTrinoQuery.instances[0].query + assert "(2)" in FakeTrinoQuery.instances[0].query + assert "(3)" in FakeTrinoQuery.instances[1].query + assert "(4)" in FakeTrinoQuery.instances[1].query + assert "(5)" in FakeTrinoQuery.instances[2].query + finally: + _dbapi._EXECUTEMANY_BATCH_SIZE = original + + def test_null_values(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, None), (None, "test")] + ) + sql = FakeTrinoQuery.instances[0].query + assert "(1, NULL)" in sql + assert "(NULL, 'test')" in sql + + def test_mixed_types(self, cursor): + cursor.executemany( + "INSERT INTO t (a, b, c, d) VALUES (?, ?, ?, ?)", + [(42, "text", True, Decimal("3.14"))] + ) + sql = FakeTrinoQuery.instances[0].query + assert "42" in sql + assert "'text'" in sql + assert "true" in sql + assert "DECIMAL '3.14'" in sql + + def test_string_escaping(self, cursor): + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [("it's a test",)] + ) + assert "it''s a test" in FakeTrinoQuery.instances[0].query + + def test_empty_params_does_not_batch(self, cursor): + cursor.executemany( + "INSERT INTO t (a) VALUES (?)", + [] + ) + # Empty params takes the execute() path, not batch path. + # The FakeTrinoQuery from execute() still gets created but + # via the non-batch code path. + assert cursor._query is not None + assert cursor._query.query == "INSERT INTO t (a) VALUES (?)" + + +class TestBatchExecutemanyOptIn: + def test_disabled_by_default(self): + """When experimental_batch_executemany is False (default), INSERT + executemany falls through to the row-by-row execute() path.""" + FakeTrinoQuery.instances = [] + original = trino.client.TrinoQuery + trino.client.TrinoQuery = FakeTrinoQuery + try: + cur = _make_cursor(batch_executemany=False) + conn = cur._connection + conn._use_legacy_prepared_statements = lambda: False + conn._create_request = lambda: FakeTrinoRequest() + cur.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "a"), (2, "b"), (3, "c")] + ) + # Row-by-row: one TrinoQuery per param set via execute() + assert len(FakeTrinoQuery.instances) == 3 + finally: + trino.client.TrinoQuery = original + + def test_enabled_batches(self): + """When experimental_batch_executemany is True, INSERT executemany + uses the batch path.""" + FakeTrinoQuery.instances = [] + original = trino.client.TrinoQuery + trino.client.TrinoQuery = FakeTrinoQuery + try: + cur = _make_cursor(batch_executemany=True) + cur.executemany( + "INSERT INTO t (a, b) VALUES (?, ?)", + [(1, "a"), (2, "b"), (3, "c")] + ) + # Batched: single TrinoQuery with all rows + assert len(FakeTrinoQuery.instances) == 1 + sql = FakeTrinoQuery.instances[0].query + assert "(1, 'a')" in sql + assert "(2, 'b')" in sql + assert "(3, 'c')" in sql + finally: + trino.client.TrinoQuery = original + + def test_connection_threads_flag_to_cursor(self): + """Connection.experimental_batch_executemany is passed to Cursor.""" + conn = Connection.__new__(Connection) + conn.experimental_batch_executemany = True + conn.legacy_primitive_types = False + conn.legacy_prepared_statements = None + conn._isolation_level = 0 # AUTOCOMMIT + conn._transaction = None + cur = Cursor.__new__(Cursor) + cur._connection = conn + cur._request = FakeTrinoRequest() + cur._iterator = None + cur._query = None + cur._legacy_primitive_types = False + cur._experimental_batch_executemany = conn.experimental_batch_executemany + assert cur._experimental_batch_executemany is True diff --git a/trino/dbapi.py b/trino/dbapi.py index 3749d797..9b6a4bb0 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -19,6 +19,7 @@ """ import datetime import math +import re import uuid from collections import OrderedDict from decimal import Decimal @@ -83,6 +84,13 @@ logger = trino.logging.get_logger(__name__) +_INSERT_VALUES_RE = re.compile( + r"\A(\s*INSERT\s+INTO\s+.+\bVALUES)\s*\([^)]+\)\s*\Z", + re.IGNORECASE | re.DOTALL, +) + +_EXECUTEMANY_BATCH_SIZE = 100 + class TimeBoundLRUCache: """A bounded LRU cache which expires entries after a configured number of seconds. @@ -161,6 +169,7 @@ def __init__( client_tags=None, legacy_primitive_types=False, legacy_prepared_statements=None, + experimental_batch_executemany=False, roles=None, timezone=None, encoding: Union[str, List[str]] = _USE_DEFAULT_ENCODING, @@ -237,6 +246,7 @@ def __init__( self._transaction = None self.legacy_primitive_types = legacy_primitive_types self.legacy_prepared_statements = legacy_prepared_statements + self.experimental_batch_executemany = experimental_batch_executemany @property def isolation_level(self): @@ -313,7 +323,8 @@ def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None) legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types - ) + ), + experimental_batch_executemany=self.experimental_batch_executemany, ) def _use_legacy_prepared_statements(self): @@ -388,7 +399,8 @@ def __init__( self, connection, request, - legacy_primitive_types: bool = False): + legacy_primitive_types: bool = False, + experimental_batch_executemany: bool = False): if not isinstance(connection, Connection): raise ValueError( "connection must be a Connection object: {}".format(type(connection)) @@ -400,6 +412,7 @@ def __init__( self._iterator = None self._query = None self._legacy_primitive_types = legacy_primitive_types + self._experimental_batch_executemany = experimental_batch_executemany def __iter__(self): return self._iterator @@ -658,15 +671,38 @@ def executemany(self, operation, seq_of_params): Return values are not defined. """ + if not seq_of_params: + self.execute(operation) + return self + + match = _INSERT_VALUES_RE.match(operation.strip()) + if match and self._experimental_batch_executemany: + return self._executemany_batch_insert(match.group(1), seq_of_params) + for parameters in seq_of_params[:-1]: self.execute(operation, parameters) self.fetchall() if self._query.update_type is None: raise NotSupportedError("Query must return update type") - if seq_of_params: - self.execute(operation, seq_of_params[-1]) - else: - self.execute(operation) + self.execute(operation, seq_of_params[-1]) + return self + + def _executemany_batch_insert(self, prefix, seq_of_params): + for i in range(0, len(seq_of_params), _EXECUTEMANY_BATCH_SIZE): + batch = seq_of_params[i:i + _EXECUTEMANY_BATCH_SIZE] + value_rows = [] + for params in batch: + formatted = ", ".join(self._format_prepared_param(p) for p in params) + value_rows.append("(%s)" % formatted) + sql = "%s %s" % (prefix, ", ".join(value_rows)) + self._query = trino.client.TrinoQuery( + self._request, query=sql, + legacy_primitive_types=self._legacy_primitive_types) + self._iterator = iter(self._query.execute()) + if self._query.update_type is None: + raise NotSupportedError("Query must return update type") + if i + _EXECUTEMANY_BATCH_SIZE < len(seq_of_params): + self.fetchall() return self def fetchone(self) -> Optional[List[Any]]: @@ -757,8 +793,10 @@ def __init__( self, connection, request, - legacy_primitive_types: bool = False): - super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types) + legacy_primitive_types: bool = False, + experimental_batch_executemany: bool = False): + super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types, + experimental_batch_executemany=experimental_batch_executemany) if self.connection._client_session.encoding is None: raise ValueError("SegmentCursor can only be used if encoding is set on the connection") diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index ad05aeec..2b1388f6 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -169,6 +169,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "legacy_prepared_statements" in url.query: kwargs["legacy_prepared_statements"] = json.loads(unquote_plus(url.query["legacy_prepared_statements"])) + if "experimental_batch_executemany" in url.query: + kwargs["experimental_batch_executemany"] = json.loads(unquote_plus(url.query["experimental_batch_executemany"])) + if "verify" in url.query: kwargs["verify"] = json.loads(unquote_plus(url.query["verify"])) From d0fcd2d74cbe5c527d1b039e9ab298a995bf03f8 Mon Sep 17 00:00:00 2001 From: Sadha Chilukoori Date: Sun, 21 Jun 2026 07:40:50 -0700 Subject: [PATCH 2/2] Fix flake8 line-length violation in SQLAlchemy dialect --- trino/sqlalchemy/dialect.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 2b1388f6..11d0735a 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -170,7 +170,8 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any kwargs["legacy_prepared_statements"] = json.loads(unquote_plus(url.query["legacy_prepared_statements"])) if "experimental_batch_executemany" in url.query: - kwargs["experimental_batch_executemany"] = json.loads(unquote_plus(url.query["experimental_batch_executemany"])) + kwargs["experimental_batch_executemany"] = json.loads( + unquote_plus(url.query["experimental_batch_executemany"])) if "verify" in url.query: kwargs["verify"] = json.loads(unquote_plus(url.query["verify"]))