Skip to content

Commit 0f227b6

Browse files
committed
Inherit from protowhat State + init and to_child refactor
1 parent 882ebd4 commit 0f227b6

11 files changed

Lines changed: 129 additions & 221 deletions

pythonwhat/State.py

Lines changed: 91 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import ast
21
import inspect
3-
import string
42
from copy import copy
53
from functools import partialmethod
64
from pythonwhat.parsing import (
@@ -9,13 +7,13 @@
97
ObjectAccessParser,
108
parser_dict,
119
)
10+
from protowhat.State import State as ProtoState
1211
from protowhat.Feedback import InstructorError
1312
from pythonwhat.Feedback import Feedback
1413
from protowhat.Test import Fail
1514
from pythonwhat import signatures
1615
from pythonwhat.converters import get_manual_converters
1716
from collections.abc import Mapping
18-
from jinja2 import Template
1917
import asttokens
2018
from pythonwhat.utils_ast import wrap_in_module
2119

@@ -41,7 +39,7 @@ def __len__(self):
4139
return len(self._items)
4240

4341

44-
class State:
42+
class State(ProtoState):
4543
"""State of the SCT environment.
4644
4745
This class holds all information relevevant to test the correctness of an exercise.
@@ -56,205 +54,114 @@ class State:
5654

5755
def __init__(
5856
self,
59-
student_context=None,
60-
solution_context=None,
61-
student_env=None,
62-
solution_env=None,
63-
student_parts=None,
64-
solution_parts=None,
57+
student_code,
58+
solution_code,
59+
pre_exercise_code,
60+
student_process,
61+
solution_process,
62+
raw_student_output,
63+
# solution output
64+
reporter,
65+
force_diagnose=False,
6566
highlight=None,
6667
highlighting_disabled=None,
6768
messages=None,
68-
force_diagnose=False,
69-
**kwargs
69+
parent_state=None,
70+
pre_exercise_ast=None,
71+
student_ast=None,
72+
solution_ast=None,
73+
student_ast_tokens=None,
74+
solution_ast_tokens=None,
75+
student_parts=None,
76+
solution_parts=None,
77+
student_context=Context(),
78+
solution_context=Context(),
79+
student_env=Context(),
80+
solution_env=Context(),
7081
):
82+
args = locals().copy()
83+
self.params = list()
7184

72-
# Set basic fields from kwargs
73-
self.__dict__.update(kwargs)
74-
75-
self.student_parts = student_parts
76-
self.solution_parts = solution_parts
77-
self.messages = messages if messages else []
78-
self.force_diagnose = force_diagnose
79-
80-
# parse code if didn't happen yet
81-
if not hasattr(self, "student_tree"):
82-
self.student_tree_tokens, self.student_tree = self.parse(self.student_code)
83-
84-
if not hasattr(self, "solution_tree"):
85-
self.solution_tree_tokens, self.solution_tree = self.parse(
86-
self.solution_code, test=False
87-
)
85+
for k, v in args.items():
86+
if k not in ["self", "args"]:
87+
self.params.append(k)
88+
setattr(self, k, v)
8889

89-
if not hasattr(self, "pre_exercise_tree"):
90-
_, self.pre_exercise_tree = self.parse(self.pre_exercise_code, test=False)
90+
if pre_exercise_ast is None:
91+
_, self.pre_exercise_ast = self.parse_internal(pre_exercise_code)
9192

92-
self.ast_dispatcher = Dispatcher(self.pre_exercise_tree)
93+
self.ast_dispatcher = self.get_dispatcher() # use updated pre_exercise_ast
9394

94-
if not hasattr(self, "parent_state"):
95-
self.parent_state = None
95+
# parse code if didn't happen yet
96+
if student_ast is None:
97+
self.student_ast = self.parse(student_code)
98+
if solution_ast is None:
99+
self.solution_ast = self.parse(solution_code, test=False)
96100

97-
self.student_context = (
98-
Context(student_context) if student_context is None else student_context
99-
)
100-
self.solution_context = (
101-
Context(solution_context) if solution_context is None else solution_context
102-
)
103-
self.student_env = Context(student_env) if student_env is None else student_env
104-
self.solution_env = (
105-
Context(solution_env) if solution_env is None else solution_env
106-
)
101+
if highlight is None and parent_state:
102+
self.highlight = self.student_ast
107103

108-
self.highlight = (
109-
self.student_tree if (not highlight) and self.parent_state else highlight
110-
)
111-
self.highlighting_disabled = highlighting_disabled
104+
self.messages = messages if messages else []
112105

113106
self.converters = get_manual_converters() # accessed only from root state
114107

115108
self.manual_sigs = None
116-
self._parser_cache = {}
117109

118110
def get_manual_sigs(self):
119111
if self.manual_sigs is None:
120112
self.manual_sigs = signatures.get_manual_sigs()
121113

122114
return self.manual_sigs
123115

124-
def build_message(self, tail="", fmt_kwargs=None, append=True):
125-
126-
if not fmt_kwargs:
127-
fmt_kwargs = {}
128-
out_list = []
129-
# add trailing message to msg list
130-
msgs = self.messages[:] + [{"msg": tail or "", "kwargs": fmt_kwargs}]
131-
# format messages in list, by iterating over previous, current, and next message
132-
for prev_d, d, next_d in zip([{}, *msgs[:-1]], msgs, [*msgs[1:], {}]):
133-
tmp_kwargs = {
134-
"parent": prev_d.get("kwargs"),
135-
"child": next_d.get("kwargs"),
136-
"this": d["kwargs"],
137-
**d["kwargs"],
138-
}
139-
# don't bother appending if there is no message
140-
if not d["msg"]:
141-
continue
142-
out = Template(d["msg"].replace("__JINJA__:", "")).render(**tmp_kwargs)
143-
out_list.append(out)
144-
145-
# if highlighting info is available, don't put all expand messages
146-
if self.highlight and not self.highlighting_disabled:
147-
out_list = out_list[-3:]
148-
149-
if append:
150-
return "".join(out_list)
151-
else:
152-
return out_list[-1]
153-
154-
def do_test(self, test):
155-
return self.reporter.do_test(test)
156-
157-
def to_child(
158-
self,
159-
student_subtree=None,
160-
solution_subtree=None,
161-
student_context=None,
162-
solution_context=None,
163-
student_env=None,
164-
solution_env=None,
165-
student_parts=None,
166-
solution_parts=None,
167-
highlight=None,
168-
highlighting_disabled=None,
169-
append_message="",
170-
node_name="",
171-
):
116+
def to_child(self, append_message="", node_name="", **kwargs):
172117
"""Dive into nested tree.
173118
174119
Set the current state as a state with a subtree of this syntax tree as
175120
student tree and solution tree. This is necessary when testing if statements or
176121
for loops for example.
177122
"""
123+
base_kwargs = {attr: getattr(self, attr) for attr in self.params if attr not in ['highlight']}
178124

179-
if isinstance(student_subtree, list):
180-
student_subtree = wrap_in_module(student_subtree)
181-
if isinstance(solution_subtree, list):
182-
solution_subtree = wrap_in_module(solution_subtree)
183-
184-
# get new contexts
185-
if solution_context is not None:
186-
solution_context = self.solution_context.update_ctx(solution_context)
187-
else:
188-
solution_context = self.solution_context
189-
190-
if student_context is not None:
191-
student_context = self.student_context.update_ctx(student_context)
192-
else:
193-
student_context = self.student_context
125+
if not isinstance(append_message, dict):
126+
append_message = {"msg": append_message, "kwargs": {}}
194127

195-
# get new envs
196-
if solution_env is not None:
197-
solution_env = self.solution_env.update_ctx(solution_env)
198-
else:
199-
solution_env = self.solution_env
128+
kwargs["messages"] = [*self.messages, append_message]
129+
kwargs["parent_state"] = self
200130

201-
if student_env is not None:
202-
student_env = self.student_env.update_ctx(student_env)
203-
else:
204-
student_env = self.student_env
131+
def update_kwarg(name, func):
132+
kwargs[name] = func(kwargs[name])
205133

206-
if highlighting_disabled is None:
207-
highlighting_disabled = self.highlighting_disabled
134+
def update_context(name):
135+
update_kwarg(name, getattr(self, name).update_ctx)
208136

209-
if not isinstance(append_message, dict):
210-
append_message = {"msg": append_message, "kwargs": {}}
137+
if isinstance(kwargs.get("student_ast", None), list):
138+
update_kwarg("student_ast", wrap_in_module)
139+
if isinstance(kwargs.get("solution_ast", None), list):
140+
update_kwarg("solution_ast", wrap_in_module)
211141

212-
messages = [*self.messages, append_message]
213-
214-
if not (solution_subtree and student_subtree):
215-
return self._update(
216-
student_context=student_context,
217-
solution_context=solution_context,
218-
student_env=student_env,
219-
solution_env=solution_env,
220-
highlight=highlight,
221-
highlighting_disabled=highlighting_disabled,
222-
messages=messages,
142+
if "student_ast" in kwargs:
143+
kwargs["student_code"] = self.student_ast_tokens.get_text(
144+
kwargs["student_ast"]
145+
)
146+
if "solution_ast" in kwargs:
147+
kwargs["solution_code"] = self.solution_ast_tokens.get_text(
148+
kwargs["solution_ast"]
223149
)
224150

225-
klass = State if not node_name else self.SUBCLASSES[node_name]
226-
child = klass(
227-
student_code=self.student_tree_tokens.get_text(student_subtree),
228-
solution_code=self.solution_tree_tokens.get_text(solution_subtree),
229-
student_tree_tokens=self.student_tree_tokens,
230-
solution_tree_tokens=self.solution_tree_tokens,
231-
pre_exercise_code=self.pre_exercise_code,
232-
student_context=student_context,
233-
solution_context=solution_context,
234-
student_env=student_env,
235-
solution_env=solution_env,
236-
student_process=self.student_process,
237-
solution_process=self.solution_process,
238-
raw_student_output=self.raw_student_output,
239-
pre_exercise_tree=self.pre_exercise_tree,
240-
student_tree=student_subtree,
241-
solution_tree=solution_subtree,
242-
student_parts=student_parts,
243-
solution_parts=solution_parts,
244-
highlight=highlight,
245-
highlighting_disabled=highlighting_disabled,
246-
messages=messages,
247-
parent_state=self,
248-
reporter=self.reporter,
249-
force_diagnose=self.force_diagnose,
250-
)
251-
return child
151+
# get new contexts
152+
if "solution_context" in kwargs:
153+
update_context("solution_context")
154+
if "student_context" in kwargs:
155+
update_context("student_context")
156+
157+
# get new envs
158+
if "solution_env" in kwargs:
159+
update_context("solution_env")
160+
if "student_env" in kwargs:
161+
update_context("student_env")
252162

253-
def _update(self, **kwargs):
254-
"""Return a copy of set, setting kwargs as attributes"""
255-
child = copy(self)
256-
for k, v in kwargs.items():
257-
setattr(child, k, v)
163+
klass = self.SUBCLASSES[node_name] if node_name else State
164+
child = klass(**{**base_kwargs, **kwargs})
258165
return child
259166

260167
def has_different_processes(self):
@@ -337,16 +244,26 @@ def parse_internal(code):
337244
def parse(self, text, test=True):
338245
if test:
339246
parse_method = self.parse_external
247+
token_attr = 'student_ast_tokens'
340248
else:
341249
parse_method = self.parse_internal
250+
token_attr = 'solution_ast_tokens'
342251

343-
return parse_method(text)
252+
tokens, ast = parse_method(text)
253+
setattr(self, token_attr, tokens)
254+
255+
return ast
256+
257+
def get_dispatcher(self):
258+
return Dispatcher(self.pre_exercise_ast)
344259

345260

346261
class Dispatcher:
347-
def __init__(self, pre_exercise_tree):
262+
def __init__(self, pre_exercise_ast):
348263
self._parser_cache = dict()
349-
self.pre_exercise_mappings = self._getx(FunctionParser, "mappings", pre_exercise_tree)
264+
self.pre_exercise_mappings = self._getx(
265+
FunctionParser, "mappings", pre_exercise_ast
266+
)
350267

351268
def __call__(self, name, node):
352269
return getattr(self, name)(node)
@@ -367,7 +284,10 @@ def _getx(self, Parser, ext_attr, tree):
367284
# otherwise, run parser over tree
368285
p = Parser()
369286
# set mappings for parsers that inspect attribute access
370-
if ext_attr != "mappings" and Parser in [FunctionParser, ObjectAccessParser]:
287+
if ext_attr != "mappings" and Parser in [
288+
FunctionParser,
289+
ObjectAccessParser,
290+
]:
371291
p.mappings = self.pre_exercise_mappings.copy()
372292
# run parser
373293
p.visit(tree)

pythonwhat/checks/check_funcs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def part_to_child(stu_part, sol_part, append_message, state, node_name=None):
2727
# if the parts are dictionaries, use to deck out child state
2828
if all(isinstance(p, dict) for p in [stu_part, sol_part]):
2929
child_state = state.to_child(
30-
student_subtree=stu_part["node"],
31-
solution_subtree=sol_part["node"],
30+
student_ast=stu_part["node"],
31+
solution_ast=sol_part["node"],
3232
student_context=stu_part.get("target_vars"),
3333
solution_context=sol_part.get("target_vars"),
3434
student_parts=stu_part,
@@ -40,8 +40,8 @@ def part_to_child(stu_part, sol_part, append_message, state, node_name=None):
4040
else:
4141
# otherwise, assume they are just nodes
4242
child_state = state.to_child(
43-
student_subtree=stu_part,
44-
solution_subtree=sol_part,
43+
student_ast=stu_part,
44+
solution_ast=sol_part,
4545
append_message=append_message,
4646
node_name=node_name,
4747
)
@@ -123,8 +123,8 @@ def check_node(
123123
if expand_msg is None:
124124
expand_msg = "Check the {{typestr}}. "
125125

126-
stu_out = state.ast_dispatcher(name, state.student_tree)
127-
sol_out = state.ast_dispatcher(name, state.solution_tree)
126+
stu_out = state.ast_dispatcher(name, state.student_ast)
127+
sol_out = state.ast_dispatcher(name, state.solution_ast)
128128

129129
# check if there are enough nodes for index
130130
fmt_kwargs = {

pythonwhat/checks/check_function.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def check_function(
9999
if params_not_matched_msg is None:
100100
params_not_matched_msg = SIG_ISSUE_MSG
101101

102-
stu_out = state.ast_dispatcher("function_calls", state.student_tree)
103-
sol_out = state.ast_dispatcher("function_calls", state.solution_tree)
102+
stu_out = state.ast_dispatcher("function_calls", state.student_ast)
103+
sol_out = state.ast_dispatcher("function_calls", state.solution_ast)
104104

105-
student_mappings = state.ast_dispatcher("mappings", state.student_tree)
105+
student_mappings = state.ast_dispatcher("mappings", state.student_ast)
106106

107107
fmt_kwargs = {
108108
"times": get_times(index + 1),

0 commit comments

Comments
 (0)