Skip to content

Commit fa839b8

Browse files
ipamoSébastien Hocquet
andauthored
Refactor named parameters feature to allow extensibility (#148)
* Refactor access to parameters through an extensible class * Refactor form controls for named parameters * Add more extendable checks and features on the Parameter class * Fix findall return unpacking in case there is only one capturing group * Fix HTML title - Parameter name should be omitted only if there is only one parameter (not only one provided value) - Parameters are omitted if they equal their default value * Escape parameter form control inputs Co-authored-by: Sébastien Hocquet <shocquet@rtm.fr>
1 parent a399acb commit fa839b8

6 files changed

Lines changed: 111 additions & 65 deletions

File tree

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{% if parameters %}
2+
<h3>Query parameters</h3>
3+
<div class="query-parameters">
4+
{% for param in parameters %}
5+
{{ param.form_control }}
6+
{% endfor %}
7+
</div>
8+
<input
9+
class="btn"
10+
type="submit"
11+
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
12+
/>
13+
{% endif %}

django_sql_dashboard/templates/django_sql_dashboard/dashboard.html

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,7 @@ <h2 style="margin-top: 0.5em">Unverified SQL</h2>
2828
{% if query_results %}
2929
<p><a href="#save-dashboard">Save this dashboard</a> | <a href="{{ request.path }}">Remove all queries</a></p>
3030
{% endif %}
31-
{% if parameter_values %}
32-
<h3>Query parameters</h3>
33-
<div class="query-parameters">
34-
{% for name, value in parameter_values %}
35-
<label for="qp{{ forloop.counter }}">{{ name }}</label>
36-
<input type="text" id="qp{{ forloop.counter }}" name="{{ name }}" value="{{ value }}">
37-
{% endfor %}
38-
</div>
39-
<input
40-
class="btn"
41-
type="submit"
42-
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
43-
/>
44-
{% endif %}
31+
{% include "django_sql_dashboard/_parameters.html" %}
4532
{% for result in query_results %}
4633
{% include result.templates with result=result %}
4734
{% endfor %}

django_sql_dashboard/templates/django_sql_dashboard/saved_dashboard.html

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,7 @@ <h1>{% if dashboard.title %}{{ dashboard.title }}{% else %}{{ dashboard.slug }}{
3030
</p>
3131

3232
<form action="{{ request.path }}" method="GET">
33-
{% if parameter_values %}
34-
<h3>Query parameters</h3>
35-
<div class="query-parameters">
36-
{% for name, value in parameter_values %}
37-
<label for="qp{{ forloop.counter }}">{{ name }}</label>
38-
<input type="text" id="qp{{ forloop.counter }}" name="{{ name }}" value="{{ value }}">
39-
{% endfor %}
40-
</div>
41-
<input
42-
class="btn"
43-
type="submit"
44-
value="Run quer{% if query_results|length > 1 %}ies{% else %}y{% endif %}"
45-
/>
46-
{% endif %}
33+
{% include "django_sql_dashboard/_parameters.html" %}
4734
{% for result in query_results %}
4835
{% include result.templates with result=result %}
4936
{% endfor %}

django_sql_dashboard/utils.py

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from collections import namedtuple
66

77
from django.core import signing
8+
from django.conf import settings
9+
from django.utils.html import escape
10+
from django.utils.safestring import mark_safe
811

912
SQL_SALT = "django_sql_dashboard:query"
1013

@@ -58,20 +61,6 @@ def displayable_rows(rows):
5861
return fixed
5962

6063

61-
_named_parameters_re = re.compile(r"\%\(([^\)]+)\)s")
62-
63-
64-
def extract_named_parameters(sql):
65-
params = _named_parameters_re.findall(sql)
66-
# Validation step: after removing params, are there
67-
# any single `%` symbols that will confuse psycopg2?
68-
without_params = _named_parameters_re.sub("", sql)
69-
without_double_percents = without_params.replace("%%", "")
70-
if "%" in without_double_percents:
71-
raise ValueError(r"Found a single % character")
72-
return params
73-
74-
7564
def check_for_base64_upgrade(queries):
7665
if not queries:
7766
return
@@ -117,3 +106,79 @@ def apply_sort(sql, sort_column, is_desc=False):
117106
else:
118107
sql = "select * from ({}) as results".format(sql)
119108
return sql + ' order by "{}"{}'.format(sort_column, " desc" if is_desc else "")
109+
110+
111+
class Parameter:
112+
extract_re = re.compile(r"\%\(([^\)]+)\)s")
113+
114+
def __init__(self, name, default_value=""):
115+
self.name = name
116+
self.default_value = self.get_sanitized(default_value, for_default=True)
117+
118+
def ensure_consistency(self, previous):
119+
if self.name != previous.name:
120+
raise ValueError("Invalid name for parameter '%s': previously registered with name '%s'" % (self.name, previous.name))
121+
if self.default_value != "" and self.default_value != previous.default_value:
122+
raise ValueError("Invalid default value '%s' for parameter '%s': previously registered with default value '%s'" % (self.default_value, self.name, previous.default_value))
123+
124+
def get_sanitized(self, value, for_default=False):
125+
if value is None or (for_default and value == "null"):
126+
return None # represents DB null value
127+
if not isinstance(value, str):
128+
raise ValueError("Invalid %svalue for parameter '%s': '%s'" % ("default " if for_default else "", self.name, type(value).__name__))
129+
return value
130+
131+
@property
132+
def value(self):
133+
return self._value if hasattr(self, "_value") else self.default_value
134+
135+
@value.setter
136+
def value(self, new_value):
137+
self._value = self.get_sanitized(new_value) if new_value != "" else self.default_value
138+
139+
def form_control(self):
140+
return mark_safe(f"""<label for="qp_{escape(self.name)}">{escape(self.name)}</label>
141+
<input type="text" id="qp_{escape(self.name)}" name="{escape(self.name)}" value="{escape(self.value) if self.value is not None else ""}">""")
142+
143+
@classmethod
144+
def extract(cls, sql: str, value_sources: list[dict[str, str]], target: list=[]):
145+
for found in cls.extract_re.findall(sql):
146+
# Ensure 'found' is an iterable of capturing groups, even if there is only one capturing group in the regex
147+
if isinstance(found, str):
148+
found = [found]
149+
new_param = cls(*found)
150+
151+
# Ensure parameters are added only once
152+
previous_param = next((param for param in target if param.name == new_param.name), None)
153+
if previous_param:
154+
new_param.ensure_consistency(previous_param)
155+
else:
156+
target.append(new_param)
157+
158+
# Validation step: after removing params, are there
159+
# any single `%` symbols that will confuse psycopg2?
160+
without_params = cls.extract_re.sub("", sql)
161+
without_double_percents = without_params.replace("%%", "")
162+
if "%" in without_double_percents:
163+
raise ValueError(r"Found a single % character")
164+
165+
# Read values form sources
166+
for param in target:
167+
for source in value_sources:
168+
if param.name in source:
169+
param.value = source[param.name]
170+
break
171+
172+
return target
173+
174+
@classmethod
175+
def execute(cls, cursor, sql: str, parameters: list=[]):
176+
values = { param.name: param.value for param in parameters }
177+
cursor.execute(sql, values)
178+
179+
PARAMETER_CLASS = getattr(settings, "DASHBOARD_PARAMETER_CLASS", Parameter)
180+
if isinstance(PARAMETER_CLASS, str):
181+
from importlib import import_module
182+
183+
[module_name, class_name] = PARAMETER_CLASS.rsplit('.', 1)
184+
PARAMETER_CLASS = getattr(import_module(module_name), class_name)

django_sql_dashboard/views.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
apply_sort,
2626
check_for_base64_upgrade,
2727
displayable_rows,
28-
extract_named_parameters,
28+
PARAMETER_CLASS,
2929
postgresql_reserved_words,
3030
sign_sql,
3131
unsign_sql,
@@ -192,10 +192,7 @@ def _dashboard_index(
192192
sql_query_parameter_errors = []
193193
for sql in sql_queries:
194194
try:
195-
extracted = extract_named_parameters(sql)
196-
for p in extracted:
197-
if p not in parameters:
198-
parameters.append(p)
195+
PARAMETER_CLASS.extract(sql, value_sources=[request.POST, request.GET], target=parameters)
199196
sql_query_parameter_errors.append(False)
200197
except ValueError as e:
201198
if "%" in sql:
@@ -205,9 +202,9 @@ def _dashboard_index(
205202
else:
206203
sql_query_parameter_errors.append(str(e))
207204
parameter_values = {
208-
parameter: request.POST.get(parameter, request.GET.get(parameter, ""))
205+
parameter.name: parameter.value
209206
for parameter in parameters
210-
if parameter != "sql"
207+
if parameter.name != "sql"
211208
}
212209
extra_qs = "&{}".format(urlencode(parameter_values)) if parameter_values else ""
213210
results_index = -1
@@ -250,7 +247,7 @@ def _dashboard_index(
250247
# Running a SELECT prevents future SET TRANSACTION READ WRITE:
251248
cursor.execute("SELECT 1;")
252249
cursor.fetchall()
253-
cursor.execute(sql, parameter_values)
250+
PARAMETER_CLASS.execute(cursor, sql, parameters)
254251
try:
255252
rows = list(cursor.fetchmany(row_limit + 1))
256253
except ProgrammingError as e:
@@ -303,12 +300,12 @@ def _dashboard_index(
303300
if dashboard and dashboard.title:
304301
html_title = dashboard.title
305302

306-
# Add named parameter values, if any exist
303+
# Add named parameter values to the page title, when they are distinct from the default values
307304
provided_values = {
308-
key: value for key, value in parameter_values.items() if value.strip()
305+
param.name: param.value for param in parameters if param.value != param.default_value
309306
}
310307
if provided_values:
311-
if len(provided_values) == 1:
308+
if len(parameters) == 1:
312309
html_title += ": {}".format(list(provided_values.values())[0])
313310
else:
314311
html_title += ": {}".format(
@@ -343,7 +340,7 @@ def _dashboard_index(
343340
"user_can_execute_sql": user_can_execute_sql,
344341
"user_can_export_data": getattr(settings, "DASHBOARD_ENABLE_FULL_EXPORT", None)
345342
and user_can_execute_sql,
346-
"parameter_values": parameter_values.items(),
343+
"parameters": parameters,
347344
"too_long_so_use_post": too_long_so_use_post,
348345
"saved_dashboards": saved_dashboards,
349346
}
@@ -410,10 +407,7 @@ def export_sql_results(request):
410407
assert format in ("csv", "tsv")
411408
sqls = request.POST.getlist("sql")
412409
sql = sqls[int(sql_index)]
413-
parameter_values = {
414-
parameter: request.POST.get(parameter, "")
415-
for parameter in extract_named_parameters(sql)
416-
}
410+
parameters = PARAMETER_CLASS.extract(sql, value_sources=[request.POST])
417411
alias = getattr(settings, "DASHBOARD_DB_ALIAS", "dashboard")
418412
# Decide on filename
419413
sql_hash = hashlib.sha256(sql.encode("utf-8")).hexdigest()[:6]
@@ -443,7 +437,7 @@ def read_and_flush():
443437

444438
def rows():
445439
try:
446-
cursor.execute(sql, parameter_values)
440+
PARAMETER_CLASS.execute(cursor, sql, parameters)
447441
done_header = False
448442
while True:
449443
records = cursor.fetchmany(size=2000)

test_project/test_parameters.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def test_parameter_form(admin_client, dashboard_db):
2626
html = response.content.decode("utf-8")
2727
# Form should have three form fields
2828
for fragment in (
29-
'<label for="qp1">foo</label>',
30-
'<input type="text" id="qp1" name="foo" value="">',
31-
'<label for="qp2">bar</label>',
32-
'<input type="text" id="qp2" name="bar" value="">',
33-
'<label for="qp3">baz</label>',
34-
'<input type="text" id="qp3" name="baz" value="">',
29+
'<label for="qp_foo">foo</label>',
30+
'<input type="text" id="qp_foo" name="foo" value="">',
31+
'<label for="qp_bar">bar</label>',
32+
'<input type="text" id="qp_bar" name="bar" value="">',
33+
'<label for="qp_baz">baz</label>',
34+
'<input type="text" id="qp_baz" name="baz" value="">',
3535
):
3636
assert fragment in html
3737

0 commit comments

Comments
 (0)