Skip to content

Commit 9bc2890

Browse files
committed
func arg for has_expr for custom equality check
1 parent 1f57ef2 commit 9bc2890

4 files changed

Lines changed: 30 additions & 4 deletions

File tree

pythonwhat/Test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,17 @@ class EqualTest(Test):
144144
result (bool): True if the test succeed, False if it failed. None if it hasn't been tested yet.
145145
"""
146146

147-
def __init__(self, obj1, obj2, feedback):
147+
def __init__(self, obj1, obj2, feedback, func = None):
148148
super().__init__(feedback)
149149
self.obj1 = obj1
150150
self.obj2 = obj2
151+
self.func = func if func is not None else is_equal
151152

152153
def specific_test(self):
153154
"""
154155
Perform the actual test. result is set to False if the objects differ, True otherwise.
155156
"""
156-
self.result = is_equal(self.obj1, self.obj2)
157+
self.result = self.func(self.obj1, self.obj2)
157158

158159

159160

pythonwhat/check_funcs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def call(args,
358358
error_msg=MSG_CALL_ERROR,
359359
# TODO kept for backwards compatibility in test_function_definition/lambda
360360
argstr='',
361+
func=None,
361362
state=None, **kwargs):
362363
rep = Reporter.active_reporter
363364
test_type = ('value', 'output', 'error')
@@ -389,7 +390,7 @@ def call(args,
389390

390391
# incorrect result
391392
_msg = state.build_message(incorrect_msg, fmt_kwargs)
392-
rep.do_test(EqualTest(eval_sol, eval_stu, Feedback(_msg, stu_node)))
393+
rep.do_test(EqualTest(eval_sol, eval_stu, Feedback(_msg, stu_node), func))
393394

394395
return state
395396

@@ -452,6 +453,7 @@ def has_expr(incorrect_msg="__JINJA__:Unexpected expression {{test}}: expected `
452453
name=None,
453454
highlight=None,
454455
copy=True,
456+
func=None,
455457
state=None,
456458
test=None):
457459
"""Run student and solution code, compare returned value, printed output, or errors.
@@ -537,7 +539,7 @@ def has_expr(incorrect_msg="__JINJA__:Unexpected expression {{test}}: expected `
537539

538540
# test equality of results
539541
_msg = state.build_message(incorrect_msg, fmt_kwargs)
540-
rep.do_test(EqualTest(eval_stu, eval_sol, Feedback(_msg, highlight)))
542+
rep.do_test(EqualTest(eval_stu, eval_sol, Feedback(_msg, highlight), func))
541543

542544
return state
543545

tests/test_test_expression_result.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,20 @@ def test_test_expression_result_copy_pass(self):
4141
sct_payload = helper.run(self.data)
4242
self.assertTrue(sct_payload['correct'])
4343

44+
def test_test_custom_equality_func(self):
45+
self.data["DC_SOLUTION"] = "a = [1.01]"
46+
self.data["DC_CODE"] = "a = [1.011]"
47+
self.data["DC_SCT"] = "import numpy as np; Ex().check_object('a').has_equal_value(func = lambda x, y: np.allclose(x, y, atol = .001))"
48+
sct_payload = helper.run(self.data)
49+
self.assertTrue(sct_payload['correct'])
50+
51+
def test_test_custom_equality_func_fail(self):
52+
self.data["DC_SOLUTION"] = "a = [1.01]"
53+
self.data["DC_CODE"] = "a = [1.011]"
54+
self.data["DC_SCT"] = "import numpy as np; Ex().check_object('a').has_equal_value(func = lambda x, y: np.allclose(x, y, atol = .0001))"
55+
sct_payload = helper.run(self.data)
56+
self.assertFalse(sct_payload['correct'])
57+
58+
4459
if __name__ == "__main__":
4560
unittest.main()

tests/test_test_function_definition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ def test_step_x_spec2_str_call(self):
2727
self.data['DC_SCT'] = """
2828
(Ex().check_function_def('test').call("f(1,2)", 'value').call("f(1,2)", 'output')
2929
.call("f('a','b')", 'error'))
30+
"""
31+
self.test_step_x()
32+
33+
def test_step_x_spec2_func_arg(self):
34+
self.data['DC_SCT'] = """
35+
import numpy as np
36+
Ex().check_function_def('test').call("f(1,2)", func = lambda x, y: np.allclose(x, y))
37+
3038
"""
3139
self.test_step_x()
3240

0 commit comments

Comments
 (0)