Skip to content

Commit 2abc3c3

Browse files
authored
Merge pull request #230 from datacamp/feature-custom-eq
func arg for has_expr for custom equality check
2 parents 1f57ef2 + 957f91e commit 2abc3c3

4 files changed

Lines changed: 31 additions & 4 deletions

File tree

pythonwhat/Test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,17 @@ class EqualTest(Test):
144144
result (bool): True if the test succeed, False if it failed. None if it hasn't been tested yet.
145145
"""
146146

147-
def __init__(self, obj1, obj2, feedback):
147+
def __init__(self, obj1, obj2, feedback, func = None):
148148
super().__init__(feedback)
149149
self.obj1 = obj1
150150
self.obj2 = obj2
151+
self.func = func if func is not None else is_equal
151152

152153
def specific_test(self):
153154
"""
154155
Perform the actual test. result is set to False if the objects differ, True otherwise.
155156
"""
156-
self.result = is_equal(self.obj1, self.obj2)
157+
self.result = self.func(self.obj1, self.obj2)
157158

158159

159160

pythonwhat/check_funcs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def call(args,
358358
error_msg=MSG_CALL_ERROR,
359359
# TODO kept for backwards compatibility in test_function_definition/lambda
360360
argstr='',
361+
func=None,
361362
state=None, **kwargs):
362363
rep = Reporter.active_reporter
363364
test_type = ('value', 'output', 'error')
@@ -389,7 +390,7 @@ def call(args,
389390

390391
# incorrect result
391392
_msg = state.build_message(incorrect_msg, fmt_kwargs)
392-
rep.do_test(EqualTest(eval_sol, eval_stu, Feedback(_msg, stu_node)))
393+
rep.do_test(EqualTest(eval_sol, eval_stu, Feedback(_msg, stu_node), func))
393394

394395
return state
395396

@@ -452,6 +453,7 @@ def has_expr(incorrect_msg="__JINJA__:Unexpected expression {{test}}: expected `
452453
name=None,
453454
highlight=None,
454455
copy=True,
456+
func=None,
455457
state=None,
456458
test=None):
457459
"""Run student and solution code, compare returned value, printed output, or errors.
@@ -483,6 +485,7 @@ def has_expr(incorrect_msg="__JINJA__:Unexpected expression {{test}}: expected `
483485
student and solution code. This could be thought of as post code.
484486
copy (bool): whether to try to deep copy objects in the environment, such as lists, that could
485487
accidentally be mutated. Disable to speed up SCTs. Disabling may lead to cryptic mutation issues.
488+
func: custom binary function of form f(stu_result, sol_result), for equality testing.
486489
"""
487490
rep = Reporter.active_reporter
488491

@@ -537,7 +540,7 @@ def has_expr(incorrect_msg="__JINJA__:Unexpected expression {{test}}: expected `
537540

538541
# test equality of results
539542
_msg = state.build_message(incorrect_msg, fmt_kwargs)
540-
rep.do_test(EqualTest(eval_stu, eval_sol, Feedback(_msg, highlight)))
543+
rep.do_test(EqualTest(eval_stu, eval_sol, Feedback(_msg, highlight), func))
541544

542545
return state
543546

tests/test_test_expression_result.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,20 @@ def test_test_expression_result_copy_pass(self):
4141
sct_payload = helper.run(self.data)
4242
self.assertTrue(sct_payload['correct'])
4343

44+
def test_test_custom_equality_func(self):
45+
self.data["DC_SOLUTION"] = "a = [1.01]"
46+
self.data["DC_CODE"] = "a = [1.011]"
47+
self.data["DC_SCT"] = "import numpy as np; Ex().check_object('a').has_equal_value(func = lambda x, y: np.allclose(x, y, atol = .001))"
48+
sct_payload = helper.run(self.data)
49+
self.assertTrue(sct_payload['correct'])
50+
51+
def test_test_custom_equality_func_fail(self):
52+
self.data["DC_SOLUTION"] = "a = [1.01]"
53+
self.data["DC_CODE"] = "a = [1.011]"
54+
self.data["DC_SCT"] = "import numpy as np; Ex().check_object('a').has_equal_value(func = lambda x, y: np.allclose(x, y, atol = .0001))"
55+
sct_payload = helper.run(self.data)
56+
self.assertFalse(sct_payload['correct'])
57+
58+
4459
if __name__ == "__main__":
4560
unittest.main()

tests/test_test_function_definition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ def test_step_x_spec2_str_call(self):
2727
self.data['DC_SCT'] = """
2828
(Ex().check_function_def('test').call("f(1,2)", 'value').call("f(1,2)", 'output')
2929
.call("f('a','b')", 'error'))
30+
"""
31+
self.test_step_x()
32+
33+
def test_step_x_spec2_func_arg(self):
34+
self.data['DC_SCT'] = """
35+
import numpy as np
36+
Ex().check_function_def('test').call("f(1,2)", func = lambda x, y: np.allclose(x, y))
37+
3038
"""
3139
self.test_step_x()
3240

0 commit comments

Comments
 (0)