|
5 | 5 | from collections import namedtuple |
6 | 6 |
|
7 | 7 | from django.core import signing |
8 | | -from django.conf import settings |
9 | | -from django.utils.html import escape |
10 | | -from django.utils.safestring import mark_safe |
11 | 8 |
|
12 | 9 | SQL_SALT = "django_sql_dashboard:query" |
13 | 10 |
|
@@ -61,6 +58,20 @@ def displayable_rows(rows): |
61 | 58 | return fixed |
62 | 59 |
|
63 | 60 |
|
| 61 | +_named_parameters_re = re.compile(r"\%\(([^\)]+)\)s") |
| 62 | + |
| 63 | + |
| 64 | +def extract_named_parameters(sql): |
| 65 | + params = _named_parameters_re.findall(sql) |
| 66 | + # Validation step: after removing params, are there |
| 67 | + # any single `%` symbols that will confuse psycopg2? |
| 68 | + without_params = _named_parameters_re.sub("", sql) |
| 69 | + without_double_percents = without_params.replace("%%", "") |
| 70 | + if "%" in without_double_percents: |
| 71 | + raise ValueError(r"Found a single % character") |
| 72 | + return params |
| 73 | + |
| 74 | + |
64 | 75 | def check_for_base64_upgrade(queries): |
65 | 76 | if not queries: |
66 | 77 | return |
@@ -106,79 +117,3 @@ def apply_sort(sql, sort_column, is_desc=False): |
106 | 117 | else: |
107 | 118 | sql = "select * from ({}) as results".format(sql) |
108 | 119 | return sql + ' order by "{}"{}'.format(sort_column, " desc" if is_desc else "") |
109 | | - |
110 | | - |
111 | | -class Parameter: |
112 | | - extract_re = re.compile(r"\%\(([^\)]+)\)s") |
113 | | - |
114 | | - def __init__(self, name, default_value=""): |
115 | | - self.name = name |
116 | | - self.default_value = self.get_sanitized(default_value, for_default=True) |
117 | | - |
118 | | - def ensure_consistency(self, previous): |
119 | | - if self.name != previous.name: |
120 | | - raise ValueError("Invalid name for parameter '%s': previously registered with name '%s'" % (self.name, previous.name)) |
121 | | - if self.default_value != "" and self.default_value != previous.default_value: |
122 | | - raise ValueError("Invalid default value '%s' for parameter '%s': previously registered with default value '%s'" % (self.default_value, self.name, previous.default_value)) |
123 | | - |
124 | | - def get_sanitized(self, value, for_default=False): |
125 | | - if value is None or (for_default and value == "null"): |
126 | | - return None # represents DB null value |
127 | | - if not isinstance(value, str): |
128 | | - raise ValueError("Invalid %svalue for parameter '%s': '%s'" % ("default " if for_default else "", self.name, type(value).__name__)) |
129 | | - return value |
130 | | - |
131 | | - @property |
132 | | - def value(self): |
133 | | - return self._value if hasattr(self, "_value") else self.default_value |
134 | | - |
135 | | - @value.setter |
136 | | - def value(self, new_value): |
137 | | - self._value = self.get_sanitized(new_value) if new_value != "" else self.default_value |
138 | | - |
139 | | - def form_control(self): |
140 | | - return mark_safe(f"""<label for="qp_{escape(self.name)}">{escape(self.name)}</label> |
141 | | -<input type="text" id="qp_{escape(self.name)}" name="{escape(self.name)}" value="{escape(self.value) if self.value is not None else ""}">""") |
142 | | - |
143 | | - @classmethod |
144 | | - def extract(cls, sql: str, value_sources: list[dict[str, str]], target: list=[]): |
145 | | - for found in cls.extract_re.findall(sql): |
146 | | - # Ensure 'found' is an iterable of capturing groups, even if there is only one capturing group in the regex |
147 | | - if isinstance(found, str): |
148 | | - found = [found] |
149 | | - new_param = cls(*found) |
150 | | - |
151 | | - # Ensure parameters are added only once |
152 | | - previous_param = next((param for param in target if param.name == new_param.name), None) |
153 | | - if previous_param: |
154 | | - new_param.ensure_consistency(previous_param) |
155 | | - else: |
156 | | - target.append(new_param) |
157 | | - |
158 | | - # Validation step: after removing params, are there |
159 | | - # any single `%` symbols that will confuse psycopg2? |
160 | | - without_params = cls.extract_re.sub("", sql) |
161 | | - without_double_percents = without_params.replace("%%", "") |
162 | | - if "%" in without_double_percents: |
163 | | - raise ValueError(r"Found a single % character") |
164 | | - |
165 | | - # Read values form sources |
166 | | - for param in target: |
167 | | - for source in value_sources: |
168 | | - if param.name in source: |
169 | | - param.value = source[param.name] |
170 | | - break |
171 | | - |
172 | | - return target |
173 | | - |
174 | | - @classmethod |
175 | | - def execute(cls, cursor, sql: str, parameters: list=[]): |
176 | | - values = { param.name: param.value for param in parameters } |
177 | | - cursor.execute(sql, values) |
178 | | - |
179 | | -PARAMETER_CLASS = getattr(settings, "DASHBOARD_PARAMETER_CLASS", Parameter) |
180 | | -if isinstance(PARAMETER_CLASS, str): |
181 | | - from importlib import import_module |
182 | | - |
183 | | - [module_name, class_name] = PARAMETER_CLASS.rsplit('.', 1) |
184 | | - PARAMETER_CLASS = getattr(import_module(module_name), class_name) |
0 commit comments