Skip to content

Commit ef3cf4a

Browse files
authored
Merge pull request #180 from datacamp/feature-check-function
Feature check function
2 parents c650dd4 + c4ae606 commit ef3cf4a

6 files changed

Lines changed: 206 additions & 36 deletions

File tree

pythonwhat/check_funcs.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def check_node(name, index, typestr, missing_msg=MSG_MISSING, expand_msg=MSG_PRE
7474

7575
# check if there are enough nodes for index
7676
fmt_kwargs = {'ordinal': get_ord(index+1) if isinstance(index, int) else "",
77-
'index': index}
77+
'index': index,
78+
'name': name}
7879
fmt_kwargs['typestr'] = typestr.format(**fmt_kwargs)
7980

8081
# test if node can be indexed succesfully
@@ -267,7 +268,8 @@ def check_args(name, missing_msg='FMT:Are you sure it is defined?', state=None):
267268
if name in ['*args', '**kwargs']:
268269
return check_part(name, name, state=state, missing_msg = missing_msg)
269270
else:
270-
return check_part_index('args', name, "argument `%s`"%name, state=state, missing_msg = missing_msg)
271+
arg_str = "%s argument"%get_ord(name+1) if isinstance(name, int) else "argument `%s`"%name
272+
return check_part_index('args', name, arg_str, state=state, missing_msg = missing_msg)
271273

272274

273275
# CALL CHECK ==================================================================
@@ -369,6 +371,18 @@ def call(args,
369371
# Expression tests ------------------------------------------------------------
370372
from pythonwhat.tasks import ReprFail, UndefinedValue
371373
from pythonwhat import utils
374+
375+
def has_equal_ast(incorrect_msg="FMT: Your code does not seem to match the solution.", state=None):
376+
rep = Reporter.active_reporter
377+
378+
stu_rep = ast.dump(state.student_tree)
379+
sol_rep = ast.dump(state.solution_tree)
380+
381+
_msg = state.build_message(incorrect_msg)
382+
rep.do_test(EqualTest(stu_rep, sol_rep, Feedback(_msg, state.highlight)))
383+
384+
return state
385+
372386
def has_expr(incorrect_msg="FMT:Unexpected expression {test}: expected `{sol_eval}`, got `{stu_eval}` with values{extra_env}.",
373387
error_msg="Running an expression in the student process caused an issue.",
374388
undefined_msg="FMT:Have you defined `{name}` without errors?",

pythonwhat/check_function.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,72 @@
1-
from pythonwhat.check_funcs import check_node
2-
from pythonwhat.test_funcs.test_function import mapped_name
1+
from pythonwhat.Reporter import Reporter
2+
from pythonwhat.check_funcs import part_to_child
3+
from pythonwhat.test_funcs.test_function import bind_args
34
from pythonwhat.tasks import getSignatureInProcess
5+
from pythonwhat.utils import get_ord
6+
from pythonwhat.Test import Test
7+
from pythonwhat.Feedback import Feedback
8+
from pythonwhat.parsing import IndexedDict
49
from functools import partial
510

6-
def check_function(name, index=0,
7-
missing_msg = "Did you define {sol_part[name]}?",
8-
expand_msg = "In your definition of {sol_part[name]}, ",
11+
def bind_args(signature, args_part):
12+
pos_args = []; kw_args = {}
13+
for k, arg in args_part.items():
14+
if isinstance(k, int): pos_args.append(arg)
15+
else: kw_args[k] = arg
16+
17+
bound_args = signature.bind(*pos_args, **kw_args)
18+
19+
return (IndexedDict(bound_args.arguments), signature)
20+
21+
MSG_PREPEND = "__JINJA__:Check your code in the {{child['part']+ ' of the' if child['part']}} {{typestr}}. "
22+
def check_function(name, index,
23+
missing_msg = "FMT:Did you define {typestr}?",
24+
params_not_matched_msg = "FMT:Something went wrong in figuring out how you specified the "
25+
"arguments for `{name}`; have another look at your code and its output.",
26+
expand_msg = MSG_PREPEND,
27+
signature=None,
28+
typestr = "{ordinal} function call",
929
state=None):
1030
rep = Reporter.active_reporter
1131
stu_out = state.student_function_calls
1232
sol_out = state.solution_function_calls
1333

14-
# test if function exists
15-
stud_name = get_mapped_name(name, state.student_mappings)
16-
17-
func_list = check_node('function_calls', name, 'function call', missing_msg, expand_msg, state)
18-
# get function state
19-
if index is None:
20-
return func_list
21-
else:
22-
# TODO make has_part more robust
23-
# grab specific function call
24-
child_func = check_part(index, "FUNCTION MSG", func_list, "not enough func calls")
25-
stu_parts, sol_parts = child_func.student_parts, child_func.solution_parts
26-
# Signatures
34+
fmt_kwargs = {'ordinal': get_ord(index+1),
35+
'index': index,
36+
'name': name}
37+
fmt_kwargs['typestr'] = typestr.format(**fmt_kwargs)
38+
39+
# Get Parts ----
40+
try:
41+
stu_parts = stu_out[name][index]
42+
except (KeyError, IndexError):
43+
_msg = state.build_message(missing_msg, fmt_kwargs)
44+
rep.do_test(Test(Feedback(_msg, state.highlight)))
45+
46+
sol_parts = sol_out[name][index]
47+
48+
# Signatures -----
49+
if signature:
50+
signature = None if isinstance(signature, bool) else signature
2751
get_sig = partial(getSignatureInProcess, name=name, signature=signature,
28-
manual_sigs = state.get_manual_sigs())
52+
manual_sigs = state.get_manual_sigs())
2953

30-
# TODO if can't parse, raise warnings
31-
sol_sig = get_sig(mapped_name=sol_parts['name'], process=solution_process)
32-
sol_parts['args'], _ = bind_ards(sol_sig, sol_parts['pos_args'], sol_parts['keywords'])
54+
try:
55+
sol_sig = get_sig(mapped_name=sol_parts['name'], process=state.solution_process)
56+
sol_parts['args'], _ = bind_args(sol_sig, sol_parts['args'])
57+
except:
58+
raise ValueError("Something went wrong in matching call index {index} of {name} to its signature. "
59+
"You might have to manually specify or correct the signature."
60+
.format(index=index, name=name))
3361

34-
# TODO if can't parse sig, send failed test msg
35-
stu_sig = get_sig(mapped_name=stu_parts['name'], process=student_process)
36-
stu_parts['args'], _ = bind_ards(stu_sig, stu_parts['pos_args'], stu_parts['keywords'])
62+
try:
63+
stu_sig = get_sig(mapped_name=stu_parts['name'], process=state.student_process)
64+
stu_parts['args'], _ = bind_args(stu_sig, stu_parts['args'])
65+
except Exception as e:
66+
_msg = state.build_message(params_not_matched_msg, fmt_kwargs)
67+
rep.do_test(Test(Feedback(_msg, state.highlight)))
3768

38-
# three types of parts: pos_args, keywords, args (e.g. these are bound to sig)
39-
return child_func
69+
# three types of parts: pos_args, keywords, args (e.g. these are bound to sig)
70+
append_message = {'msg': expand_msg, 'kwargs': fmt_kwargs}
71+
child = part_to_child(stu_parts, sol_parts, append_message, state, node_name='function_calls')
72+
return child

pythonwhat/check_wrappers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pythonwhat.check_funcs import check_part, check_part_index, check_node, has_equal_part
22
from pythonwhat import check_funcs, check_object
3+
from pythonwhat.check_function import check_function
34
from pythonwhat.test_funcs.test_data_frame import check_df
45
from pythonwhat.test_funcs.test_dictionary import check_dict
56
from pythonwhat import test_funcs
@@ -53,9 +54,10 @@
5354

5455
for k, v in __NODE_WRAPPERS__.items():
5556
scts['check_'+k] = partial(check_node, k+'s', typestr=v)
57+
scts['check_function'] = check_function
5658

5759
for k in ['set_context',
58-
'has_equal_value', 'has_equal_output', 'has_equal_error', 'call',
60+
'has_equal_value', 'has_equal_output', 'has_equal_error', 'has_equal_ast', 'call',
5961
'extend', 'multi', 'test_not', 'fail', 'quiet',
6062
'with_context',
6163
'check_args',

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
setup(
66
name='pythonwhat',
7-
version='2.1.2',
7+
version='2.2.0',
88
packages=['pythonwhat', 'pythonwhat.test_funcs'],
99
install_requires=["dill", "IPython", "numpy", "pandas", "markdown2"]
1010
)

tests/test_test_function.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ def test_Pass(self):
3434
self.assertTrue(sct_payload['correct'])
3535
self.assertEqual(sct_payload['message'], "Great!")
3636

37+
def test_Pass_spec2(self):
38+
self.data['DC_SCT'] = """
39+
Ex().check_function('print', 0).check_args(0).has_equal_ast()
40+
"""
41+
sct_payload = helper.run(self.data)
42+
self.assertTrue(sct_payload['correct'])
43+
3744
class TestFunctionExerciseNumpy(unittest.TestCase):
3845

3946
def setUp(self):
@@ -599,9 +606,7 @@ def setUp(self):
599606
test_function("print", index = 3, highlight=True)
600607
'''
601608
}
602-
self.DC_SCT_SPEC2 = '''
603-
Ex().check_function("print", 0).check_arg(0).has_equal_value()
604-
'''
609+
605610
def test_multiple_1(self):
606611
self.data["DC_CODE"] = 'print("abc")'
607612
sct_payload = helper.run(self.data)
@@ -659,6 +664,110 @@ def test_nohighlight_too_few_calls(self):
659664
self.assertFalse(sct_payload['correct'])
660665
self.assertEqual(sct_payload.get('line_start'), None)
661666

667+
class TestCheckFunction(unittest.TestCase):
668+
def setUp(self):
669+
self.data = {
670+
"DC_PEC": "import numpy as np",
671+
"DC_CODE": "np.array([1,2,3])",
672+
"DC_SOLUTION": "np.array([1,2,3])",
673+
"DC_SCT": "Ex().check_function('numpy.array', 0)"
674+
}
675+
676+
def run_append(self, sct):
677+
self.data["DC_SCT"] += sct
678+
return helper.run(self.data)
679+
680+
def run_pass(self, sct):
681+
sct_payload = self.run_append(sct)
682+
print(sct_payload)
683+
self.assertTrue(sct_payload['correct'])
684+
return sct_payload
685+
686+
def run_fail(self, sct):
687+
self.assertFalse(self.run_append(sct)['correct'])
688+
689+
def test_pass_np_call_exists(self):
690+
sct_payload = helper.run(self.data)
691+
self.assertTrue(sct_payload['correct'])
692+
693+
def test_pass_test_student_typed(self):
694+
self.run_pass(".test_student_typed(r'np\.array\(\[1,2,3\]\)')")
695+
696+
def test_fail_test_student_typed(self):
697+
self.data["DC_CODE"] = "np.array([1,2])"
698+
self.run_fail(".test_student_typed(r'np\.array\(\[1,2,3\]\)')")
699+
700+
def test_pass_func_has_equal_ast(self):
701+
self.run_pass(".has_equal_ast()")
702+
703+
def test_fail_func_has_equal_ast(self):
704+
self.data["DC_CODE"] = "np.array([1,2])"
705+
self.run_fail(".has_equal_ast()")
706+
707+
def test_pass_check_args_pos_0(self):
708+
self.run_pass(".check_args(0)")
709+
710+
def test_fail_check_args_pos_0(self):
711+
self.data["DC_CODE"] = "np.array()"
712+
self.run_fail(".check_args(0)")
713+
714+
def test_pass_pos_0_test_student_typed(self):
715+
self.run_pass(".check_args(0).test_student_typed(r'\[1,2,3\]')")
716+
717+
def test_fail_pos_0_test_student_typed(self):
718+
self.data["DC_CODE"] = "np.array([1,2])"
719+
self.run_fail(".check_args(0).test_student_typed(r'\[1,2,3\]')")
720+
721+
def test_pass_pos_0_has_equal_ast(self):
722+
self.run_pass(".check_args(0).has_equal_ast()")
723+
724+
def test_fail_pos_0_has_equal_ast(self):
725+
self.data["DC_CODE"] = "np.array([1,2])"
726+
self.run_fail(".check_args(0).has_equal_ast()")
727+
728+
def test_pass_pos_0_has_equal_value(self):
729+
self.run_pass(".check_args(0).has_equal_value()")
730+
731+
def test_fail_pos_0_has_equal_value(self):
732+
self.data["DC_CODE"] = "np.array([1,2])"
733+
self.run_fail(".check_args(0).has_equal_value()")
734+
735+
def test_pass_pos_0_inline_if_body(self):
736+
self.data["DC_CODE"] = "np.array([1,2,3] if True else [1])"
737+
self.data["DC_SOLUTION"] = "np.array([1,2,3] if False else [1])"
738+
self.run_pass(".check_args(0).check_if_exp(0).check_body().has_equal_ast()")
739+
740+
def test_fail_pos_0_inline_if_body(self):
741+
self.data["DC_CODE"] = "np.array([1,2,3] if True else [1])"
742+
self.data["DC_SOLUTION"] = "np.array([1,2] if False else [1])"
743+
self.run_fail(".check_args(0).check_if_exp(0).check_body().has_equal_ast()")
744+
745+
class TestCheckFunctionCases(unittest.TestCase):
746+
def setup_color(self):
747+
self.data = {
748+
'DC_PEC': "def f(*args, **kwargs): pass",
749+
'DC_CODE': "f(color = 'blue')"
750+
}
751+
self.data["DC_SOLUTION"] = self.data["DC_CODE"]
752+
753+
def test_pass_sig_false(self):
754+
self.setup_color()
755+
self.data['DC_SCT'] = "Ex().check_function('f', 0, signature=False).check_args('color').has_equal_ast()"
756+
757+
sct_payload = helper.run(self.data)
758+
self.assertTrue(sct_payload['correct'])
759+
760+
@unittest.skip("TODO: implement override")
761+
def test_pass_sig_false_override(self):
762+
self.setup_color()
763+
self.data["DC_SCT"].replace('color', 'c')
764+
self.data['DC_SCT'] = """
765+
Ex().check_function('f', 0, signature=False).override("f(c = 'blue')").check_args('c').has_equal_ast()
766+
"""
767+
768+
sct_payload = helper.run(self.data)
769+
self.assertTrue(sct_payload['correct'])
770+
662771

663772
class TestFunctionComplexArgs(unittest.TestCase):
664773
def setUp(self):
@@ -702,7 +811,5 @@ def test_fail_undillable_args(self):
702811
self.assertFalse(sct_payload['correct'])
703812

704813

705-
706-
707814
if __name__ == "__main__":
708815
unittest.main()

tests/test_test_function_v2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,20 +669,34 @@ def setUp(self):
669669
'''
670670
}
671671

672+
self.SPEC2_SCT = """
673+
Ex().check_function('pandas.DataFrame', 0, missing_msg = "notcalledmsg", expand_msg="")\
674+
.check_args('data', missing_msg='paramsnotmatchedmsg')
675+
"""
676+
672677
def test_step1(self):
673678
self.data["DC_CODE"] = ""
674679
sct_payload = helper.run(self.data)
675680
self.assertFalse(sct_payload['correct'])
676681
self.assertEqual('notcalledmsg', sct_payload['message'])
677682
helper.test_absent_lines(self, sct_payload)
678683

684+
def test_step1_spec2(self):
685+
self.data["DC_SCT"] = self.SPEC2_SCT
686+
self.test_step1()
687+
679688
def test_step2(self):
680689
self.data["DC_CODE"] = "df = pd.DataFrame(x=[1, 2, 3])"
681690
sct_payload = helper.run(self.data)
682691
self.assertFalse(sct_payload['correct'])
683692
self.assertEqual('paramsnotmatchedmsg', sct_payload['message'])
684693
helper.test_lines(self, sct_payload, 1, 1, 6, 30)
685694

695+
def test_step2_spec2(self):
696+
self.data["DC_SCT"] = self.SPEC2_SCT
697+
self.test_step2()
698+
699+
686700
def test_step3(self):
687701
self.data["DC_CODE"] = "df = pd.DataFrame(data=[1, 2, 3])"
688702
sct_payload = helper.run(self.data)

0 commit comments

Comments
 (0)