Skip to content

Commit ecba003

Browse files
authored
Merge pull request #181 from datacamp/feature-override
Feature override
2 parents c6cd747 + a57eab5 commit ecba003

5 files changed

Lines changed: 118 additions & 7 deletions

File tree

pythonwhat/check_funcs.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_not(*args, msg, state=None):
179179
return state
180180

181181
_msg = state.build_message(msg)
182-
return rep.do_test(Test(msg))
182+
return rep.do_test(Test(_msg))
183183

184184
# utility functions -----------------------------------------------------------
185185

@@ -194,10 +194,39 @@ def fail(msg="", state=None):
194194
"""Fail test with message"""
195195
rep = Reporter.active_reporter
196196
_msg = state.build_message(msg)
197-
rep.do_test(Test(Feedback(msg, state.highlight)))
197+
rep.do_test(Test(Feedback(_msg, state.highlight)))
198198

199199
return state
200200

201+
import ast
202+
def override(solution, state=None):
203+
"""Change the focused solution code."""
204+
205+
# the old ast may be a number of node types, but generally either a
206+
# (1) ast.Module, or for single expressions...
207+
# (2) whatever was grabbed using module.body[0]
208+
# (3) module.body[0].value, when module.body[0] is an Expr node
209+
old_ast = state.solution_tree
210+
new_ast = ast.parse(solution)
211+
if not isinstance(old_ast, ast.Module) and len(new_ast.body) == 1:
212+
expr = new_ast.body[0]
213+
candidates = [expr, expr.value] if isinstance(expr, ast.Expr) else [expr]
214+
for node in candidates:
215+
if isinstance(node, old_ast.__class__):
216+
new_ast = node
217+
break
218+
219+
kwargs = state.messages[-1] if state.messages else {}
220+
child = state.to_child_state(
221+
solution_subtree = new_ast,
222+
student_subtree = state.student_tree,
223+
highlight = state.highlight,
224+
append_message = {'msg': "", 'kwargs': kwargs}
225+
)
226+
227+
return child
228+
229+
201230
# context functions -----------------------------------------------------------
202231

203232
from pythonwhat.tasks import setUpNewEnvInProcess, breakDownNewEnvInProcess

pythonwhat/check_wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858

5959
for k in ['set_context',
6060
'has_equal_value', 'has_equal_output', 'has_equal_error', 'has_equal_ast', 'call',
61-
'extend', 'multi', 'test_not', 'fail', 'quiet',
61+
'extend', 'multi', 'test_not', 'fail', 'quiet', 'override',
6262
'with_context',
6363
'check_args',
6464
'has_equal_part']:

pythonwhat/parsing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,9 @@ def parse_node(cls, node):
608608
kwargs = cls.get_arg_part(node.args.kwarg, None, 'kwarg')
609609
all_args = [*args, varargs, *kw_args, kwargs]
610610

611+
if isinstance(node, ast.Lambda): body_node = node.body
612+
else: body_node = FunctionBodyTransformer().visit(ast.Module(node.body))
613+
611614
return {
612615
"node": node,
613616
"name": getattr(node, 'name', None),
@@ -616,7 +619,7 @@ def parse_node(cls, node):
616619
"_spec1_args": args,
617620
"*args": varargs,
618621
"**kwargs": kwargs,
619-
"body": {'node': FunctionBodyTransformer().visit(ast.Module(node.body)),
622+
"body": {'node': body_node,
620623
'target_vars': TargetVars(target_vars)}
621624
}
622625

tests/test_spec.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,76 @@ def test_fail(self):
173173
self.assertFalse(sct_payload['correct'])
174174

175175

176+
class TestOverride(unittest.TestCase):
177+
"""
178+
This class is used to test overriding w/ correct and incorrect code. Tests are
179+
run for entire nodes (e.g. an if block) and their parts (e.g. body of if block)
180+
"""
181+
182+
def do_exercise(self, code, base_check, parts, override=None, part_name = None, part_index = "", passes=True):
183+
"""High level function used to generate tests"""
184+
if part_name:
185+
if not override: override = parts[part_name]
186+
sct = base_check + '.check_{}({}).override("""{}""").has_equal_ast()'\
187+
.format(part_name, part_index, override)
188+
else:
189+
# whole code (e.g. if expression, or for loop)
190+
if not override: override = code.format(**parts)
191+
sct = base_check + '.override("""{}""").has_equal_ast()'.format(override)
192+
193+
data = {
194+
"DC_SOLUTION": code.format(**parts),
195+
"DC_CODE": code.format(**parts),
196+
"DC_SCT": sct
197+
}
198+
sct_payload = helper.run(data)
199+
self.assertTrue(sct_payload['correct']) if passes else self.assertFalse(sct_payload['correct'])
200+
201+
# used to generate tests
202+
EXPRESSIONS = {
203+
'if_exp': "{body} if {test} else {orelse}",
204+
'list_comp': "[{body} for i in {iter}]",
205+
'dict_comp': "{{ {key}: {value} for i in {iter} }}",
206+
'for_loop': "for i in {iter}: {body}",
207+
'while': "while {test}: {body}",
208+
'try_except': "try: {body}\nexcept: pass\nelse: {orelse}",
209+
'lambda_function': "lambda a={args}: {body}",
210+
'function_def': ["'sum'", "def sum(a={args}): {body}"],
211+
'function': ["'sum', 0", "sum({args})"]
212+
}
213+
214+
PARTS = {'body': "1", "test": "False", 'orelse': "2", 'iter': "range(3)",
215+
'key': "3", 'value': "4", 'args': "(1,2,3)"}
216+
217+
import re
218+
def gen_exercise(*args, **kwargs):
219+
return lambda self: TestOverride.do_exercise(self, *args, **kwargs)
220+
221+
for k, code in TestOverride.EXPRESSIONS.items():
222+
# base SCT, w/ special indexing if function checks
223+
if isinstance(code, list): indx, code = code
224+
else: indx = '0'
225+
base_check = "Ex().check_{}({})".format(k, indx)
226+
# pass overall test ----
227+
pf = gen_exercise(code, base_check, TestOverride.PARTS)
228+
setattr(TestOverride, 'test_{}_pass'.format(k), pf)
229+
# fail overall test ----
230+
pf = gen_exercise(code, base_check, TestOverride.PARTS, override="'WRONG ANSWER'", passes=False)
231+
setattr(TestOverride, 'test_{}_fail'.format(k), pf)
232+
# test individual pieces --------------------------------------------------
233+
for part in re.findall("\{([^{]*?)\}", code): # find all str.format vars, e.g. {body}
234+
part_index = "" if part != 'args' else 0
235+
# pass individual piece ----
236+
test_name = 'test_{}_{}_pass'.format(k, part)
237+
pf = gen_exercise(code, base_check, TestOverride.PARTS, part_name=part, part_index=part_index)
238+
setattr(TestOverride, test_name, pf)
239+
# fail individual piece ----
240+
test_name = 'test_{}_{}_fail'.format(k, part)
241+
bad_code = code.format(**{part: "[]", **TestOverride.PARTS})
242+
pf = gen_exercise(code, base_check, TestOverride.PARTS, part_name=part, part_index=part_index, override=bad_code, passes=False)
243+
setattr(TestOverride, test_name, pf)
244+
245+
176246
#class TestSetContext(unittest.TestCase):
177247
# def setUp(self):
178248
# self.data = {

tests/test_test_function.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -757,17 +757,26 @@ def test_pass_sig_false(self):
757757
sct_payload = helper.run(self.data)
758758
self.assertTrue(sct_payload['correct'])
759759

760-
@unittest.skip("TODO: implement override")
761760
def test_pass_sig_false_override(self):
762761
self.setup_color()
763-
self.data["DC_SCT"].replace('color', 'c')
762+
self.data["DC_CODE"] = self.data["DC_CODE"].replace('color', 'c')
764763
self.data['DC_SCT'] = """
765-
Ex().check_function('f', 0, signature=False).override("f(c = 'blue')").check_args('c').has_equal_ast()
764+
Ex().override("f(c = 'blue')").check_function('f', 0, signature=False).check_args('c').has_equal_ast()
766765
"""
767766

768767
sct_payload = helper.run(self.data)
769768
self.assertTrue(sct_payload['correct'])
770769

770+
@unittest.skip("TODO: override code isn't parsed, so can't get args part")
771+
def test_pass_sig_false_override_after_check(self):
772+
self.setup_color()
773+
self.data["DC_CODE"] = self.data["DC_CODE"].replace('color', 'c')
774+
self.data['DC_SCT'] = """
775+
Ex().check_function('f', 0, signature=False).override("f(c = 'blue')").check_args('c').has_equal_ast()
776+
"""
777+
sct_payload = helper.run(self.data)
778+
self.assertTrue(sct_payload['correct'])
779+
771780

772781
class TestFunctionComplexArgs(unittest.TestCase):
773782
def setUp(self):

0 commit comments

Comments
 (0)