Skip to content

Commit 25a1f67

Browse files
Make Db2 dialect statement-cache compatible (#188)
Signed-off-by: Balram Choudhary <bchoudhary@rocketsoftware.com>
1 parent b593008 commit 25a1f67

6 files changed

Lines changed: 98 additions & 41 deletions

File tree

ibm_db_sa/base.py

Lines changed: 92 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from sqlalchemy import types as sa_types
2525
from sqlalchemy import schema as sa_schema
2626
from sqlalchemy import util
27+
from sqlalchemy import exc
28+
from sqlalchemy.sql.elements import BindParameter
2729
from sqlalchemy.sql import compiler
2830
from sqlalchemy.sql import operators
2931
from sqlalchemy.engine import default
@@ -396,42 +398,87 @@ def visit_mod_binary(self, binary, operator, **kw):
396398
return "mod(%s, %s)" % (self.process(binary.left),
397399
self.process(binary.right))
398400

399-
def limit_clause(self, select, **kwargs):
400-
limit = select._limit
401-
offset = select._offset or 0
402-
403-
if limit is not None:
404-
if offset > 0:
405-
return f" LIMIT {limit} OFFSET {offset}"
406-
else:
407-
return f" LIMIT {limit}"
408-
return ""
401+
def literalBindsFlagFrom_kw(self, kw=None):
402+
"""Return True if literal_binds is requested in compile kwargs."""
403+
if not kw or not isinstance(kw, dict):
404+
return False
405+
if kw.get("literal_binds"):
406+
return True
407+
ck = kw.get("compile_kwargs")
408+
if isinstance(ck, dict) and ck.get("literal_binds"):
409+
return True
410+
return False
411+
412+
def limit_clause(self, select, **kw):
413+
text = ""
414+
limit_clause = select._limit_clause
415+
offset_clause = select._offset_clause
416+
literal_binds = self.literalBindsFlagFrom_kw(kw)
409417

410-
def visit_select(self, select, **kwargs):
411-
limit, offset = select._limit, select._offset
412-
sql_ori = compiler.SQLCompiler.visit_select(self, select, **kwargs)
418+
def _render_clause(clause):
419+
if clause is None:
420+
return None
421+
if select._simple_int_clause(clause):
422+
return self.process(clause.render_literal_execute(), **kw)
423+
if literal_binds:
424+
if hasattr(clause, "render_literal_execute"):
425+
try:
426+
return self.process(clause.render_literal_execute(), **kw)
427+
except Exception:
428+
pass
429+
try:
430+
return self.process(clause, literal_binds=True, **kw)
431+
except Exception:
432+
pass
433+
try:
434+
if isinstance(clause, BindParameter):
435+
val = getattr(clause, "value", None)
436+
if val is not None:
437+
if isinstance(val, str):
438+
return f"'{val}'"
439+
return str(val)
440+
except Exception:
441+
pass
442+
try:
443+
return self.process(clause, **kw)
444+
except Exception as e:
445+
raise exc.CompileError(
446+
"dialect 'ibm_db_sa' cannot render LIMIT/OFFSET for this clause; "
447+
"ensure the clause is a simple integer or is processable by the compiler."
448+
) from e
449+
450+
limit_text = _render_clause(limit_clause)
451+
if limit_text is not None:
452+
text += " LIMIT %s" % limit_text
453+
offset_text = _render_clause(offset_clause)
454+
if offset_text is not None:
455+
text += " OFFSET %s" % offset_text
456+
return text
413457

414-
if ('LIMIT' in sql_ori.upper()) or ('FETCH FIRST' in sql_ori.upper()):
458+
def visit_select(self, select, **kw):
459+
sql_ori = compiler.SQLCompiler.visit_select(self, select, **kw)
460+
if ("LIMIT" in sql_ori.upper()) or ("FETCH FIRST" in sql_ori.upper()):
415461
return sql_ori
416-
417-
if limit is not None:
418-
sql = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip()
419-
limit_offset_clause = self.limit_clause(select, **kwargs)
420-
sql += limit_offset_clause
421-
return sql
422-
423-
if offset is not None:
462+
limit_clause_obj = select._limit_clause
463+
offset_clause_obj = select._offset_clause
464+
if limit_clause_obj is not None:
465+
limit_offset_clause = self.limit_clause(select, **kw)
466+
if limit_offset_clause:
467+
return sql_ori + limit_offset_clause
468+
if offset_clause_obj is not None:
424469
__rownum = 'Z.__ROWNUM'
425-
sql_split = re.split(r"[\s+]FROM ", sql_ori, 1)
470+
sql_work = re.sub(r'FETCH FIRST \d+ ROWS ONLY', '', sql_ori, flags=re.IGNORECASE).strip()
471+
sql_work = re.sub(r'\s+OFFSET\s+(?:\d+|__\[POSTCOMPILE_[^\]]+\]|:[A-Za-z0-9_]+|\?)\s*$', '', sql_work,
472+
flags=re.IGNORECASE)
473+
sql_split = re.split(r"[\s+]FROM ", sql_work, 1)
474+
if len(sql_split) < 2:
475+
return sql_ori
426476
sql_sec = " \nFROM %s " % (sql_split[1])
427-
428477
dummyVal = "Z.__db2_"
429478
sql_pri = ""
430-
431479
sql_sel = "SELECT "
432480
if select._distinct:
433481
sql_sel = "SELECT DISTINCT "
434-
435482
sql_select_token = sql_split[0].split(",")
436483
i = 0
437484
while i < len(sql_select_token):
@@ -440,32 +487,41 @@ def visit_select(self, select, **kwargs):
440487
sql_pri = f'{sql_pri} {sql_select_token[i]},{sql_select_token[i + 1]},{sql_select_token[i + 2]},{sql_select_token[i + 3]} AS "{dummyVal}{i + 1}",'
441488
i += 4
442489
continue
443-
444490
if sql_select_token[i].count(" AS ") == 1:
445491
temp_col_alias = sql_select_token[i].split(" AS ")
446492
sql_pri = f'{sql_pri} {sql_select_token[i]},'
447493
sql_sel = f'{sql_sel} {temp_col_alias[1]},'
448494
i += 1
449495
continue
450-
451496
sql_pri = f'{sql_pri} {sql_select_token[i]} AS "{dummyVal}{i + 1}",'
452497
sql_sel = f'{sql_sel} "{dummyVal}{i + 1}",'
453498
i += 1
454-
455499
sql_pri = sql_pri.rstrip(",")
456500
sql_pri = f"{sql_pri}{sql_sec}"
457501
sql_sel = sql_sel.rstrip(",")
458502
sql = f'{sql_sel}, ( ROW_NUMBER() OVER() ) AS "{__rownum}" FROM ( {sql_pri} ) AS M'
459503
sql = f'{sql_sel} FROM ( {sql} ) Z WHERE'
460504

461-
if offset != 0:
462-
sql = f'{sql} "{__rownum}" > {offset}'
463-
if offset != 0 and limit is not None:
505+
def _process_clause_text(clause):
506+
if clause is None:
507+
return None
508+
if select._simple_int_clause(clause):
509+
return self.process(clause.render_literal_execute(), **kw)
510+
else:
511+
return self.process(clause, **kw)
512+
513+
offset_text = _process_clause_text(offset_clause_obj)
514+
limit_text = _process_clause_text(limit_clause_obj)
515+
if offset_text is not None:
516+
sql = f'{sql} "{__rownum}" > {offset_text}'
517+
if offset_text is not None and limit_text is not None:
464518
sql = f'{sql} AND '
465-
if limit is not None:
466-
sql = f'{sql} "{__rownum}" <= {offset + limit}'
519+
if limit_text is not None:
520+
if offset_text is not None:
521+
sql = f'{sql} "{__rownum}" <= ({offset_text} + {limit_text})'
522+
else:
523+
sql = f'{sql} "{__rownum}" <= {limit_text}'
467524
return f"( {sql} )"
468-
469525
return sql_ori
470526

471527
def visit_sequence(self, sequence, **kw):
@@ -753,7 +809,7 @@ class DB2Dialect(default.DefaultDialect):
753809
supports_sane_multi_rowcount = True
754810
supports_native_decimal = False
755811
supports_native_boolean = False
756-
supports_statement_cache = False
812+
supports_statement_cache = True
757813
preexecute_sequences = False
758814
supports_alter = True
759815
supports_sequences = True

ibm_db_sa/ibm_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def get_result_proxy(self):
9999
class DB2Dialect_ibm_db(DB2Dialect):
100100
driver = 'ibm_db_sa'
101101
supports_unicode_statements = True
102-
supports_statement_cache = False
102+
supports_statement_cache = True
103103
supports_sane_rowcount = True
104104
supports_sane_multi_rowcount = False
105105
supports_native_decimal = False

ibm_db_sa/pyodbc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class DB2Dialect_pyodbc(PyODBCConnector, DB2Dialect):
3030
supports_unicode_statements = True
3131
supports_char_length = True
3232
supports_native_decimal = False
33-
supports_statement_cache = False
33+
supports_statement_cache = True
3434

3535
execution_ctx_cls = DB2ExecutionContext_pyodbc
3636

ibm_db_sa/reflection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
class CoerceUnicode(sa_types.TypeDecorator):
3131
impl = sa_types.Unicode
32+
cache_ok = True
3233

3334
def process_bind_param(self, value, dialect):
3435
if isinstance(value, str):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[build-system]
2-
requires = ["setuptools>=42", "wheel"]
2+
requires = ["setuptools>=42", "wheel", "packaging>=20.0"]
33
build-backend = "setuptools.build_meta"

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
readme = os.path.join(os.path.dirname(__file__), 'README.md')
1313
if 'USE_PYODBC' in os.environ and os.environ['USE_PYODBC'] == '1':
14-
require = ['sqlalchemy>=0.7.3']
14+
require = ['sqlalchemy>=0.7.3', 'packaging>=20.0']
1515
else:
16-
require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0']
16+
require = ['sqlalchemy>=0.7.3','ibm_db>=2.0.0', 'packaging>=20.0']
1717

1818

1919
setup(

0 commit comments

Comments
 (0)