-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathsql_output_adapter.py
More file actions
108 lines (96 loc) · 3.7 KB
/
sql_output_adapter.py
File metadata and controls
108 lines (96 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding: utf-8 -*-
# Format names handled by ``adapter``: one INSERT style plus UPDATE styles.
# "sql-update-N" uses the first N columns as the WHERE key; plain
# "sql-update" uses one key column.
supported_formats = ("sql-insert", "sql-update", "sql-update-1", "sql-update-2")

# Rows reach ``adapter`` untouched: no cli_helpers preprocessors are applied.
preprocessors = tuple()
def escape_for_sql_statement(value):
    """Return *value* rendered as a SQL literal.

    - ``None`` becomes ``NULL``.
    - Binary values (``bytes``, ``bytearray``, ``memoryview``) become a hex
      blob literal, e.g. ``X'00ff'``.
    - Anything else is converted with ``str()`` and wrapped in single quotes.
      Embedded single quotes are doubled (standard SQL escaping), so
      ``O'Brien`` becomes ``'O''Brien'`` instead of breaking the statement.

    NOTE: backslashes are left as-is. That is correct for standard SQL and for
    PostgreSQL with standard_conforming_strings on. MySQL, however, treats the
    backslash as an escape character in string literals unless the
    NO_BACKSLASH_ESCAPES SQL mode is enabled.
    """
    if value is None:
        return "NULL"
    # Some drivers (e.g. psycopg2 for bytea) return memoryview, not bytes.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return f"X'{value.hex()}'"
    return "'{}'".format(str(value).replace("'", "''"))
def _quote_identifier(name, delimiter):
    """Wrap *name* in *delimiter*, doubling any embedded delimiter characters."""
    name = str(name)
    if delimiter:
        name = name.replace(delimiter, delimiter * 2)
    return f"{delimiter}{name}{delimiter}"


def adapter(data, headers, table_format=None, **kwargs):
    """Render query results as SQL INSERT or UPDATE statements, line by line.

    ``register_new_formatter`` registers this generator as the handler for
    every name in ``supported_formats``. The target table is the first table
    that ``extract_tables`` finds in ``formatter.query``; ``DUAL`` is used
    when none is found.

    Parameters:
        data: query result rows; each row is iterated, indexed and sliced by
            column position.
        headers: column names, in the same order as the values of each row.
        table_format: one of ``supported_formats``.
            ``"sql-insert"`` yields one multi-row INSERT statement.
            ``"sql-update"`` / ``"sql-update-N"`` yields one UPDATE per row,
            using the first N columns (1 when N is omitted) as the WHERE key
            and the remaining columns as the SET list.
            Any other value (including None) yields nothing.
        kwargs:
            extract_tables: required callable returning the tables referenced
                by a query, for example ``extract_tables`` in
                ``pgcli.packages.parseutils.tables``.
            delimiter: character used to quote table and column names so they
                cannot clash with SQL keywords, e.g. a backtick for MySQL or a
                double quote for PostgreSQL. Defaults to a double quote when
                missing or not a str.

    Raises:
        ValueError: if no ``extract_tables`` function was supplied (raised on
            first iteration, since this is a generator).
    """
    extract_table_func = kwargs.get("extract_tables")
    if not extract_table_func:
        raise ValueError("extract_tables function should be registered first")
    # ``formatter`` is the module global set by register_new_formatter.
    tables = extract_table_func(formatter.query)

    delimiter = kwargs.get("delimiter")
    if not isinstance(delimiter, str):
        delimiter = '"'

    def quote(name):
        return _quote_identifier(name, delimiter)

    if tables:
        schema, name = tables[0][:2]
        # Quote schema and table separately: quoting "schema.table" as one
        # identifier would refer to a table whose name contains a dot.
        table = f"{quote(schema)}.{quote(name)}" if schema else quote(name)
    else:
        table = quote("DUAL")

    if table_format == "sql-insert":
        columns = ", ".join(quote(h) for h in headers)
        yield f"INSERT INTO {table} ({columns}) VALUES"
        prefix = " "
        for row in data:
            values = ", ".join(escape_for_sql_statement(v) for v in row)
            yield f"{prefix}({values})"
            # Every row after the first is comma-separated from the previous.
            prefix = ", "
        yield ";"
    elif (table_format or "").startswith("sql-update"):
        # "sql-update-N" -> N key columns; plain "sql-update" -> 1.
        parts = table_format.split("-")
        keys = int(parts[-1]) if len(parts) > 2 else 1
        for row in data:
            yield f"UPDATE {table} SET"
            prefix = " "
            for i, value in enumerate(row[keys:], keys):
                yield f"{prefix}{quote(headers[i])} = {escape_for_sql_statement(value)}"
                prefix = ", "
            # "col = NULL" never matches in SQL, so NULL keys need IS NULL.
            where = " AND ".join(
                f"{quote(headers[i])} IS NULL"
                if row[i] is None
                else f"{quote(headers[i])} = {escape_for_sql_statement(row[i])}"
                for i in range(keys)
            )
            yield f"WHERE {where};"
def register_new_formatter(TabularOutputFormatter, **kwargs):
    """Register every name in ``supported_formats`` with *TabularOutputFormatter*.

    Each format is registered with ``adapter`` as its handler and
    ``preprocessors`` as its preprocessors.

    Parameters:
        TabularOutputFormatter: default TabularOutputFormatter imported from
            cli_helpers.
        kwargs: options forwarded to ``adapter`` for every format; should
            include ``extract_tables`` and may include ``delimiter``.
            For example ``{"delimiter": "`", "extract_tables": extract_tables}``.
    """
    global formatter
    # adapter() reads ``formatter.query`` to determine the target table.
    formatter = TabularOutputFormatter
    for sql_format in supported_formats:
        # Build a fresh dict per format. cli_helpers keeps a reference to the
        # kwargs dict it is given, so mutating one shared dict in this loop
        # would leave every format seeing the last table_format written.
        TabularOutputFormatter.register_new_formatter(
            sql_format,
            adapter,
            preprocessors,
            {**kwargs, "table_format": sql_format},
        )