forked from googleapis/python-spanner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_dml_executor.py
More file actions
117 lines (95 loc) · 3.99 KB
/
batch_dml_executor.py
File metadata and controls
117 lines (95 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Statement,
)
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted
from google.cloud.spanner_dbapi.transaction_helper import (
_get_batch_statements_result_checksum,
)
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
if TYPE_CHECKING:
from google.cloud.spanner_dbapi.cursor import Cursor
class BatchDmlExecutor:
    """Collects DML statements for a DML batch and runs them together.

    While a DML batch is active only DML statements are accepted. Each
    accepted statement is kept in an in-memory list; nothing is sent to
    Spanner until :meth:`run_batch_dml` is called.

    :type cursor: :class:`~google.cloud.spanner_dbapi.cursor.Cursor`
    :param cursor: cursor whose connection is used to run the batch
    """

    def __init__(self, cursor: "Cursor"):
        self._cursor = cursor
        self._connection = cursor.connection
        # Statements buffered so far, in the order they were submitted.
        self._statements: List[Statement] = []

    def execute_statement(self, parsed_statement: ParsedStatement):
        """Buffer a statement in memory while the DML batch is active.

        :type parsed_statement: ParsedStatement
        :param parsed_statement: parsed statement containing sql query and
            query params

        :raises ProgrammingError: if the statement type is neither
            ``UPDATE`` nor ``INSERT``.
        """
        # Imported here rather than at module level, as in the original
        # (presumably to avoid a circular import -- confirm).
        from google.cloud.spanner_dbapi import ProgrammingError

        accepted_types = (StatementType.UPDATE, StatementType.INSERT)
        if parsed_statement.statement_type not in accepted_types:
            raise ProgrammingError(
                "Only DML statements are allowed in batch DML mode."
            )
        self._statements.append(parsed_statement.statement)

    def run_batch_dml(self):
        """Send every buffered statement to Spanner as one DML batch.

        Delegates to the module-level ``run_batch_dml`` with this
        executor's cursor and buffered statements, and returns its result.
        """
        return run_batch_dml(self._cursor, self._statements)
def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
    """Executes all the dml statements by making a batch call to Spanner.

    If the connection has not started a client-side transaction, the batch
    is run through ``connection.database.run_in_transaction``. Otherwise it
    is run on the transaction returned by
    ``connection.transaction_checkout()``; when that attempt reports
    ``ABORTED`` the helper's ``retry_transaction()`` is called and the batch
    is attempted again.

    In both modes ``cursor._row_count`` is set to the sum of the returned
    per-statement counts, with negative counts counted as zero, and the
    populated result sets are returned.

    :type cursor: Cursor
    :param cursor: Database Cursor object

    :type statements: List[Statement]
    :param statements: list of statements to execute in batch

    :rtype: StreamedManyResultSets
    :returns: result sets to which the batch's per-statement results were
        added

    :raises OperationalError: if the batch finishes with a status code
        that is neither ``OK`` nor ``ABORTED``.
    """
    from google.cloud.spanner_dbapi import OperationalError

    connection = cursor.connection
    transaction_helper = connection._transaction_helper
    many_result_set = StreamedManyResultSets()
    statements_tuple = [statement.get_tuple() for statement in statements]

    if not connection._client_transaction_started:
        res = connection.database.run_in_transaction(
            _do_batch_update, statements_tuple
        )
        many_result_set.add_iter(res)
        cursor._row_count = sum(max(val, 0) for val in res)
        # Return the populated result sets in this mode as well. Without
        # this the function fell off the end and returned None here, while
        # the transaction branch below returned many_result_set.
        return many_result_set

    while True:
        try:
            transaction = connection.transaction_checkout()
            status, res = transaction.batch_update(statements_tuple)
            if status.code == ABORTED:
                # Clear the connection's transaction before signalling the
                # abort (presumably so the next checkout starts a fresh
                # one -- confirm against Connection.transaction_checkout).
                connection._transaction = None
                raise Aborted(status.message)
            elif status.code != OK:
                raise OperationalError(status.message)

            # Record the statements with a checksum of their results on the
            # transaction helper (presumably consulted when the transaction
            # is retried -- confirm in transaction_helper).
            checksum = _get_batch_statements_result_checksum(res, status.code)
            many_result_set.add_iter(res)
            transaction_helper._batch_statements_list.append(
                (statements, checksum)
            )
            cursor._row_count = sum(max(val, 0) for val in res)
            return many_result_set
        except Aborted:
            # Let the helper retry the transaction, then loop to run the
            # batch again.
            transaction_helper.retry_transaction()
def _do_batch_update(transaction, statements):
    """Run ``statements`` as one batch update on ``transaction``.

    Passed by ``run_batch_dml`` to ``database.run_in_transaction`` as the
    unit of work.

    :param transaction: object exposing ``batch_update(statements)``, which
        returns a ``(status, results)`` pair
    :param statements: statements to pass to ``batch_update``

    :returns: the results returned by ``batch_update`` when the status
        code is ``OK``

    :raises Aborted: if the status code is ``ABORTED``
    :raises OperationalError: for any other non-``OK`` status code
    """
    from google.cloud.spanner_dbapi import OperationalError

    status, row_counts = transaction.batch_update(statements)
    code = status.code
    if code == ABORTED:
        raise Aborted(status.message)
    if code != OK:
        raise OperationalError(status.message)
    return row_counts
class BatchMode(Enum):
    """Batch kinds: ``DML``, ``DDL``, or ``NONE``.

    NOTE(review): presumably records which kind of batch, if any, is
    currently open -- the code that reads it is outside this block.
    """

    DML = 1
    DDL = 2
    NONE = 3