-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtest_selectors.py
More file actions
123 lines (90 loc) · 3.54 KB
/
test_selectors.py
File metadata and controls
123 lines (90 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from protowhat.selectors import Selector, Dispatcher
from sqlwhat.State import State, PARSER_MODULES
import importlib
from protowhat.Reporter import Reporter
from protowhat.failure import TestFail as TF
import pytest
@pytest.fixture
def ast():
    """Return the PostgreSQL parser module registered in ``PARSER_MODULES``.

    Note: this fixture's name shadows the stdlib ``ast`` module for every
    test that requests it as a parameter.
    """
    postgres_module_name = PARSER_MODULES["postgresql"]
    return importlib.import_module(postgres_module_name)
@pytest.fixture
def dispatcher(ast):
    """Build a ``Dispatcher`` from the parser module supplied by the ``ast`` fixture."""
    built = Dispatcher.from_module(ast)
    return built
def test_selector_standalone(monkeypatch):
    """Selector can walk plain Python ``ast`` nodes, independent of any SQL parser.

    The test attaches a ``_priority`` class attribute to the stdlib node
    classes (presumably Selector reads it to decide how deep to descend).
    ``monkeypatch`` is used so these attributes are removed again at teardown
    instead of lingering on Python's own ``ast`` classes for the rest of the
    test session.
    """
    from ast import Expr, Num  # use python's builtin ast library

    # NOTE(review): ast.Num is deprecated since Python 3.8 and removed in
    # 3.14; confirm which node class this test should target going forward.
    # raising=False: the stdlib classes have no _priority attribute to begin
    # with, and monkeypatch deletes the attribute again on undo.
    monkeypatch.setattr(Expr, "_priority", 0, raising=False)
    monkeypatch.setattr(Num, "_priority", 1, raising=False)

    node = Expr(value=Num(n=1))
    sel = Selector(Num)
    sel.visit(node)
    assert isinstance(sel.out[0], Num)
def test_selector_on_self(ast, monkeypatch):
    """A Selector visiting a node of its own target class selects that node.

    ``Terminal.DEBUG`` is switched on through ``monkeypatch`` rather than by
    plain assignment. The ``ast`` fixture returns the module cached by
    ``importlib``, so a direct assignment would leave the flag set for every
    later test in the session and make their results order-dependent.
    """
    # Don't simplify terminals; monkeypatch restores the previous value at
    # teardown. raising=False keeps this working even if DEBUG is not
    # predefined on Terminal.
    monkeypatch.setattr(ast.Terminal, "DEBUG", True, raising=False)
    terminal = ast.Terminal.from_text("test")
    sel = Selector(ast.Terminal)
    sel.visit(terminal)
    assert sel.out[0] == terminal
# tests using actual parsed ASTs ----------------------------------------------
def build_and_run(sql_expr, ast_class, ast_mod, priority=None):
    """Parse *sql_expr* with *ast_mod* and return the nodes selected for *ast_class*.

    Args:
        sql_expr: SQL text handed to ``ast_mod.parse``.
        ast_class: node class the ``Selector`` looks for.
        ast_mod: parser module providing ``parse``.
        priority: forwarded unchanged to ``Selector``; ``None`` keeps its default.

    Returns:
        The selector's ``out`` collection after visiting the parsed tree.
    """
    parsed = ast_mod.parse(sql_expr)
    matcher = Selector(ast_class, priority=priority)
    matcher.visit(parsed)
    return matcher.out
def test_selector_on_script(ast):
    """A single SELECT statement yields exactly one ``SelectStmt`` node."""
    selected = build_and_run("SELECT id FROM artists", ast.SelectStmt, ast)
    assert len(selected) == 1
    first = selected[0]
    assert type(first) is ast.SelectStmt
def test_selector_default_priority(ast):
    """With the default priority, no ``Identifier`` nodes are selected."""
    found = build_and_run("SELECT id FROM artists", ast.Identifier, ast)
    assert not found
def test_selector_set_high_priority(ast):
    """A high priority lets the selector reach both identifiers in the query."""
    found = build_and_run(
        "SELECT id FROM artists", ast.Identifier, ast, priority=999
    )
    assert len(found) == 2
    assert all(type(node) is ast.Identifier for node in found)
def test_selector_set_low_priority(ast):
    """An explicit priority of 0 selects no ``Identifier`` nodes."""
    found = build_and_run(
        "SELECT id FROM artists", ast.Identifier, ast, priority=0
    )
    assert not found
def test_selector_omits_subquery(ast):
    """At default priority only the outer SELECT is selected, not the nested one."""
    query = "SELECT a FROM x WHERE a = (SELECT b FROM y)"
    selects = build_and_run(query, ast.SelectStmt, ast)

    assert len(selects) == 1
    assert all(type(stmt) is ast.SelectStmt for stmt in selects)
    outer = selects[0]
    # The surviving statement is the outer one: its first target is column "a".
    assert outer.target_list[0].fields == ["a"]
def test_selector_includes_subquery(ast):
    """With a high priority the selector also descends into the subquery.

    Two ``SelectStmt`` nodes are expected. The nested ``SELECT b FROM y`` is
    expected at index 1, and its repr must match the tree produced by parsing
    that text directly with the ``subquery`` start rule.
    """
    out = build_and_run(
        "SELECT a FROM x WHERE a = (SELECT b FROM y)", ast.SelectStmt, ast, priority=999
    )
    # Check the count first so a missing subquery fails with a clear
    # assertion instead of an IndexError on out[1].
    assert len(out) == 2
    select1 = out[1]
    select2 = ast.parse(
        "SELECT b FROM y", start="subquery"
    )  # subquery is the parser rule for select statements
    assert repr(select1) == repr(select2)
def test_selector_head(ast):
    """Re-selecting from a found node with ``head=True`` reaches the nested BinaryExpr.

    ``1 + 2 + 3`` contains nested binary expressions. The first selection
    yields a single node whose right operand is ``"3"``. Visiting that node
    again with ``head=True`` (presumably this skips the starting node itself)
    yields a single inner node whose left operand is ``"1"``.
    """
    expression = ast.parse("1 + 2 + 3", "expression")

    outer_sel = Selector(ast.BinaryExpr)
    outer_sel.visit(expression)
    assert len(outer_sel.out) == 1
    top = outer_sel.out[0]
    assert top.right == "3"

    inner_sel = Selector(ast.BinaryExpr)
    inner_sel.visit(top, head=True)
    assert len(inner_sel.out) == 1
    assert inner_sel.out[0].left == "1"
def test_dispatch_select(dispatcher, ast):
    """``Dispatcher.find`` resolves the type name "SelectStmt" to a SelectStmt node."""
    tree = ast.parse("SELECT id FROM artists")
    matches = dispatcher.find("SelectStmt", tree)
    assert type(matches[0]) is ast.SelectStmt
def test_selecting_base_nodes(ast):
    """Select a generic ``AstNode`` by class name, directly and via a Dispatcher.

    The raw ANTLR tree is processed into ``ast.AstNode`` objects. A
    ``Selector`` built with ``strict=False`` and the name ``"Query_block"``
    must find exactly one such node. ``Dispatcher.find`` with the same name
    and priority must return the same node.
    """
    node_name = "Query_block"
    priority = 3

    raw_tree = ast.parse_ast(ast.grammar, "SELECT a FROM b", start="sql_script")
    tree = ast.process_tree(raw_tree)

    selector = Selector(ast.AstNode, node_name, strict=False, priority=priority)
    selector.visit(tree)
    matches = selector.out
    assert len(matches) == 1
    match = matches[0]
    assert isinstance(match, ast.AstNode)
    assert match.__class__.__name__ == node_name

    base_dispatcher = Dispatcher(ast.AstNode, ast_mod=ast)
    dispatched = base_dispatcher.find(node_name, tree, priority=priority)[0]
    assert dispatched == match