-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathcheck_funcs.py
More file actions
360 lines (261 loc) · 12.5 KB
/
check_funcs.py
File metadata and controls
360 lines (261 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
from protowhat.Feedback import Feedback, FeedbackComponent
from protowhat.failure import debugger
from protowhat.sct_syntax import link_to_state
from protowhat.checks.check_simple import allow_errors
from sqlwhat.checks import has_no_error, has_result, has_equal_value
def allow_error(state):
    """Deprecated alias of ``allow_errors()``.

    Lets the submission pass even when running it raised a database error.
    New SCTs should call ``Ex().allow_errors()`` directly when the exercise
    intentionally produces an error.

    Returns:
        The same ``state`` object that was passed in, so the call can be chained.
    """
    allow_errors(state)  # delegate to the renamed protowhat check
    return state
def check_row(state, index, missing_msg=None, expand_msg=None):
    """Zoom in on a particular row in the query result, by index.

    After zooming in on a row, which is represented as a single-row query result,
    you can use ``has_equal_value()`` to verify whether all columns in the zoomed in solution
    query result have a match in the student query result.

    Args:
        index: index of the row to zoom in on (zero-based indexed).
        missing_msg: if specified, this overrides the automatically generated feedback
            message in case the row is missing in the student query result.
            The template receives ``index`` as a variable.
        expand_msg: if specified, this overrides the automatically generated feedback
            message that is prepended to feedback messages that are thrown
            further in the SCT chain. The template receives ``index`` as a variable.

    Returns:
        A child state whose student and solution results contain only the selected row.

    Raises:
        BaseException: if ``index`` is past the last row of the *solution* query result,
            which points at a problem in the SCT rather than in the submission.

    :Example:

        Suppose we are testing the following SELECT statements

        * solution: ``SELECT artist_id as id, name FROM artists LIMIT 5``
        * student : ``SELECT artist_id, name FROM artists LIMIT 2``

        We can write the following SCTs: ::

            # fails, since row 3 at index 2 is not in the student result
            Ex().check_row(2)

            # passes, since row 2 at index 1 is in the student result
            Ex().check_row(1)
    """
    # ``index + 1`` in the templates turns the zero-based index into a 1-based row number.
    if missing_msg is None:
        missing_msg = "The system wants to verify row {{index + 1}} of your query result, but couldn't find it. Have another look."
    if expand_msg is None:
        expand_msg = "Have another look at row {{index + 1}} in your query result. "
    msg_kwargs = {"index": index}

    # check that query returned something
    has_result(state)

    stu_res = state.student_result
    sol_res = state.solution_result

    # Results are column-oriented (column name -> sequence of cell values), so the
    # length of any one column is the number of rows.
    n_sol = len(next(iter(sol_res.values())))
    n_stu = len(next(iter(stu_res.values())))

    if index >= n_sol:
        raise BaseException(
            "There are only {} rows in the solution query result, and you're trying to fetch the row at index {}".format(
                n_sol, index
            )
        )

    if index >= n_stu:
        state.report(missing_msg, msg_kwargs)

    # Keep every column, but only the selected cell, wrapped in a list so the
    # child result is still a (single-row) column-oriented result.
    return state.to_child(
        append_message=FeedbackComponent(expand_msg, msg_kwargs),
        student_result={k: [v[index]] for k, v in stu_res.items()},
        solution_result={k: [v[index]] for k, v in sol_res.items()},
    )
def check_column(state, name, missing_msg=None, expand_msg=None):
    """Zoom in on a single column of the query result, selected by its name.

    The selected column is exposed in a child state as a single-column query result,
    so ``has_equal_value()`` can then compare the solution's column against the student's.

    Args:
        name: name of the column to zoom in on.
        missing_msg: if specified, replaces the default feedback message used when the
            student query result has no column called ``name``.
        expand_msg: if specified, replaces the default message that is prepended to
            feedback produced further down the SCT chain.

    Returns:
        A child state holding only column ``name`` for both student and solution.

    Raises:
        BaseException: if the solution query result itself has no column ``name``.

    :Example:

        Suppose we are testing the following SELECT statements

        * solution: ``SELECT artist_id as id, name FROM artists``
        * student : ``SELECT artist_id, name FROM artists``

        We can write the following SCTs: ::

            # fails, since no column named id in student result
            Ex().check_column('id')

            # passes, since a column named name is in student_result
            Ex().check_column('name')
    """
    missing_msg = (
        missing_msg
        if missing_msg is not None
        else "We expected to find a column named `{{name}}` in the result of your query, but couldn't."
    )
    expand_msg = (
        expand_msg
        if expand_msg is not None
        else "Have another look at your query result. "
    )
    template_vars = {"name": name}

    # zooming in only makes sense once the query is known to have produced a result
    has_result(state)

    student_cols = state.student_result
    solution_cols = state.solution_result

    if name not in solution_cols:
        raise BaseException("name %s not in solution column names" % name)
    if name not in student_cols:
        state.report(missing_msg, template_vars)

    return state.to_child(
        append_message=FeedbackComponent(expand_msg, template_vars),
        student_result={name: student_cols[name]},
        solution_result={name: solution_cols[name]},
    )
def check_all_columns(state, allow_extra=True, too_many_cols_msg=None, expand_msg=None):
    """Zoom in on the columns that are specified by the solution.

    Behind the scenes, this is using ``check_column()`` for every column that is in the solution query result.
    Afterwards, it's selecting only these columns from the student query result and stores them in a child
    state that is returned, so you can use ``has_equal_value()`` on it.

    This function does not allow you to customize the messages for ``check_column()``. If you want to manually
    set those, simply use ``check_column()`` explicitly.

    Args:
        allow_extra: True by default, this determines whether students are allowed to have included
            other columns in their query result.
        too_many_cols_msg: If specified, this overrides the automatically generated feedback message in
            case ``allow_extra`` is False and the student's query returned extra columns when
            compared to the solution query result. The template receives the name of the first
            extra column (in student column order) as ``col``.
        expand_msg: if specified, this overrides the automatically generated feedback
            message that is prepended to feedback messages that are thrown
            further in the SCT chain.

    Returns:
        A child state whose student and solution results contain only the solution's columns.

    :Example:

        Consider the following solution and SCT: ::

            # solution
            SELECT artist_id as id, name FROM artists

            # sct
            Ex().check_all_columns()

            # passing submission
            SELECT artist_id as id, name FROM artists

            # failing submission (wrong names)
            SELECT artist_id, name FROM artists

            # passing submission (allow_extra is True by default)
            SELECT artist_id as id, name, label FROM artists
    """
    if too_many_cols_msg is None:
        # Same text that was previously hard-coded in the report call below.
        too_many_cols_msg = "Your query result contains the column `{{col}}` but shouldn't."
    if expand_msg is None:
        expand_msg = "Have another look at your query result. "

    child_stu_result = {}
    child_sol_result = {}
    for col in state.solution_result:
        # check_column() reports a missing student column and returns a one-column child state.
        child = check_column(state, col)
        child_stu_result.update(child.student_result)
        child_sol_result.update(child.solution_result)

    # Walk the student's columns in order (rather than a set difference) so the
    # column named in the feedback is deterministic across runs.
    cols_not_in_sol = [
        col for col in state.student_result if col not in child_stu_result
    ]

    if not allow_extra and cols_not_in_sol:
        state.report(too_many_cols_msg, {"col": cols_not_in_sol[0]})

    return state.to_child(
        append_message=FeedbackComponent(expand_msg),
        student_result=child_stu_result,
        solution_result=child_sol_result,
    )
def lowercase(state):
    """Return a child state in which every result column name is lower-cased.

    Both the student and the solution results are rewritten, which makes later
    column lookups insensitive to the case the query used for its aliases.

    :Example:

        Suppose we are testing the following SELECT statements

        * solution: ``SELECT artist_id as id FROM artists``
        * student : ``SELECT artist_id as ID FROM artists``

        We can write the following SCTs: ::

            # fails, as id and ID have different case
            Ex().check_column('id').has_equal_value()

            # passes, as lowercase() is being used
            Ex().lowercase().check_column('id').has_equal_value()
    """

    def _lower_keys(result):
        # Rebuild the mapping with lower-cased column names; values are untouched.
        return {column.lower(): cells for column, cells in result.items()}

    return state.to_child(
        student_result=_lower_keys(state.student_result),
        solution_result=_lower_keys(state.solution_result),
    )
def check_result(state):
    """High level check that compares the full query result with the solution's.

    The steps run in this order:

    * ``lowercase()`` normalizes column names,
    * ``check_all_columns()`` selects the solution's columns from that state,
    * ``has_equal_value()`` compares the selected columns.

    Returns:
        The state produced by ``check_all_columns()``.
    """
    lowered = link_to_state(lowercase)(state)
    zoomed = link_to_state(check_all_columns)(lowered)
    has_equal_value(zoomed)
    return zoomed
def check_query(state, query, error_msg=None, expand_msg=None):
    """Run arbitrary queries against the DB connection to verify the database state.

    For queries that do not return any output (INSERTs, UPDATEs, ...),
    you cannot use functions like ``check_column()`` and ``has_equal_value()`` to verify the query result.

    ``check_query()`` will rerun the solution query in the transaction prepared by sqlbackend,
    and immediately afterwards run the query specified in ``query``.

    Next, it will also run this query after rerunning the student query in a transaction.

    Finally, it produces a child state with these results, that you can then chain off of
    with functions like ``check_column()`` and ``has_equal_value()``.

    Args:
        query: A SQL query as a string that is executed after the student query is re-executed.
        error_msg: if specified, this overrides the automatically generated feedback
            message in case the query generated an error. The template receives ``query``.
        expand_msg: if specified, this overrides the automatically generated feedback
            message that is prepended to feedback messages that are thrown
            further in the SCT chain. The template receives ``query``.

    Returns:
        A child state whose student/solution results are the results of ``query``.

    :Example:

        Suppose we are checking whether an INSERT happened correctly: ::

            INSERT INTO company VALUES (2, 'filip', 28, 'sql-lane', 42)

        We can write the following SCT: ::

            Ex().check_query('SELECT COUNT(*) AS c FROM company').has_equal_value()
    """
    if error_msg is None:
        error_msg = "Running `{{query}}` after your submission generated an error."
    if expand_msg is None:
        expand_msg = "The autograder verified the result of running `{{query}}` against the database. "
    msg_kwargs = {"query": query}

    # before redoing the query,
    # make sure that it didn't generate any errors
    has_no_error(state)

    # sqlbackend makes sure all queries are run in transactions.
    # Rerun the solution code first, after which we run the provided query
    with dbconn(state.solution_conn) as conn:
        _ = runQuery(conn, state.solution_code)
        sol_res = runQuery(conn, query)

    if sol_res is None:
        # The solution should always run cleanly; this is reported under
        # ``debugger`` (presumably flagged as an exercise problem, not student feedback).
        # msg_kwargs is passed so ``{{query}}`` in error_msg is rendered.
        with debugger(state):
            state.report("Solution failed: " + error_msg, msg_kwargs)

    # sqlbackend makes sure all queries are run in transactions.
    # Rerun the student code first, after which we run the provided query
    with dbconn(state.student_conn) as conn:
        _ = runQuery(conn, state.student_code)
        stu_res = runQuery(conn, query)

    if stu_res is None:
        state.report(error_msg, msg_kwargs)

    return state.to_child(
        append_message=FeedbackComponent(expand_msg, msg_kwargs),
        student_result=stu_res,
        solution_result=sol_res,
    )
# The functions below are almost exact copies of what is in sqlbackend,
# but they already use a conn object inside a transaction (this is how sqlbackend calls test_exercise).
from collections import OrderedDict
from contextlib import contextmanager
@contextmanager
def dbconn(conn):
    """Yield a sub-connection of ``conn`` inside a transaction that is never committed.

    A new connection is obtained with ``conn.connect()`` and a transaction is begun on it.
    When the ``with`` block exits, normally or because of an exception, the transaction
    is rolled back and the sub-connection is closed. Any changes made inside the block
    are therefore discarded.

    Cleanup is best-effort: errors raised by ``rollback()`` or ``close()`` are suppressed
    so they cannot mask the outcome of the ``with`` body. An exception raised inside the
    ``with`` body still propagates to the caller.

    Args:
        conn: object exposing ``connect()``; the returned connection must expose
            ``begin()`` and ``close()`` (presumably a SQLAlchemy Engine/Connection,
            TODO confirm).
    """
    sub_conn = conn.connect()
    trans = sub_conn.begin()
    try:
        yield sub_conn
    finally:
        # try/finally is required: with @contextmanager an exception in the
        # ``with`` body is re-raised at the ``yield``, skipping code after it.
        try:
            trans.rollback()
        except Exception:
            # we tried; still attempt to close below
            pass
        try:
            sub_conn.close()
        except Exception:
            # we tried
            pass
def runQuery(conn, code):
    """Execute ``code`` through ``conn``'s DB-API connection and collect the result.

    Args:
        conn: object whose ``.connection`` attribute is a DB-API connection
            (the cursor is obtained via ``conn.connection.cursor()``).
        code: SQL to run; it is converted with ``str()`` before execution.

    Returns:
        * An empty ``OrderedDict`` when ``code`` is falsy (nothing is executed).
        * An ``OrderedDict`` mapping each column name, in cursor order, to a tuple of
          that column's values, one per row.
        * ``None`` when execution fails, or when the statement yields no fetchable
          result set (e.g. ``cursor.description`` is ``None``).

    Note:
        A query that returns zero rows yields an empty ``OrderedDict``; column names
        are not kept in that case.
    """
    try:
        if not code:
            return OrderedDict()
        cursor = conn.connection.cursor()
        try:
            cursor.execute(str(code))
            try:
                records = [list(record) for record in cursor.fetchall()]
                columns = [i[0] for i in cursor.description]
                # Transpose the row-oriented records into column -> values.
                result = OrderedDict(zip(columns, list(zip(*records))))
            except Exception:
                # No result set to fetch (DB-API may raise, or description is None).
                result = None
        finally:
            # Always release the cursor, including when execute() itself fails.
            cursor.close()
        return result
    except Exception:
        return None