Skip to content

Commit 9b68b54

Browse files
committed
parsing join
1 parent 902ec09 commit 9b68b54

3 files changed

Lines changed: 116 additions & 34 deletions

File tree

abstra_json_sql/ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class Wildcard(Ast):
149149
class Join(Ast):
150150
table: str
151151
table_alias: str
152-
join_type: Literal["INNER", "LEFT", "RIGHT"]
152+
join_type: Literal["INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL"]
153153
on: Expression
154154

155155

abstra_json_sql/parser.py

Lines changed: 109 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
OrderField,
3333
PlusExpression,
3434
MinusExpression,
35+
Join,
3536
MultiplyExpression,
3637
DivideExpression,
3738
Expression,
@@ -79,6 +80,39 @@ def parse_expression(tokens: List[Token]) -> Tuple[Expression, List[Token]]:
7980
while tokens:
8081
next_token = tokens[0]
8182
tokens = tokens[1:]
83+
84+
if next_token.type == "keyword":
85+
if next_token.value.upper() == "AND":
86+
left = stack.pop()
87+
right, tokens = parse_expression(tokens)
88+
stack.append(AndExpression(left=left, right=right))
89+
elif next_token.value.upper() == "OR":
90+
left = stack.pop()
91+
right, tokens = parse_expression(tokens)
92+
stack.append(OrExpression(left=left, right=right))
93+
elif next_token.value.upper() == "NOT":
94+
tokens = tokens[1:]
95+
expression, tokens = parse_expression(tokens)
96+
stack.append(NotExpression(expression=expression))
97+
elif next_token.value.upper() == "TRUE":
98+
stack.append(TrueExpression())
99+
tokens = tokens[1:]
100+
elif next_token.value.upper() == "FALSE":
101+
stack.append(FalseExpression())
102+
tokens = tokens[1:]
103+
elif next_token.value.upper() == "IS":
104+
left = stack.pop()
105+
right, tokens = parse_expression(tokens)
106+
stack.append(IsExpression(left=left, right=right, is_not=False))
107+
elif next_token.value.upper() == "IS NOT":
108+
left = stack.pop()
109+
right, tokens = parse_expression(tokens)
110+
stack.append(IsExpression(left=left, right=right, is_not=True))
111+
elif next_token.value.upper() == "NULL":
112+
stack.append(NullExpression())
113+
else:
114+
tokens = [next_token] + tokens
115+
break
82116
if next_token.type == "int":
83117
stack.append(IntExpression(value=int(next_token.value)))
84118
elif next_token.type == "float":
@@ -158,38 +192,6 @@ def parse_expression(tokens: List[Token]) -> Tuple[Expression, List[Token]]:
158192
stack.append(NotEqualExpression(left=left, right=right))
159193
else:
160194
raise ValueError(f"Unknown operator: {operator}")
161-
elif next_token.type == "keyword":
162-
if next_token.value.upper() == "AND":
163-
left = stack.pop()
164-
right, tokens = parse_expression(tokens)
165-
stack.append(AndExpression(left=left, right=right))
166-
elif next_token.value.upper() == "OR":
167-
left = stack.pop()
168-
right, tokens = parse_expression(tokens)
169-
stack.append(OrExpression(left=left, right=right))
170-
elif next_token.value.upper() == "NOT":
171-
tokens = tokens[1:]
172-
expression, tokens = parse_expression(tokens)
173-
stack.append(NotExpression(expression=expression))
174-
elif next_token.value.upper() == "TRUE":
175-
stack.append(TrueExpression())
176-
tokens = tokens[1:]
177-
elif next_token.value.upper() == "FALSE":
178-
stack.append(FalseExpression())
179-
tokens = tokens[1:]
180-
elif next_token.value.upper() == "IS":
181-
left = stack.pop()
182-
right, tokens = parse_expression(tokens)
183-
stack.append(IsExpression(left=left, right=right, is_not=False))
184-
elif next_token.value.upper() == "IS NOT":
185-
left = stack.pop()
186-
right, tokens = parse_expression(tokens)
187-
stack.append(IsExpression(left=left, right=right, is_not=True))
188-
elif next_token.value.upper() == "NULL":
189-
stack.append(NullExpression())
190-
else:
191-
tokens = [next_token] + tokens
192-
break
193195
else:
194196
tokens = [next_token] + tokens
195197
break
@@ -238,6 +240,68 @@ def parse_group_by(tokens: List[Token]) -> Tuple[Optional[GroupBy], List[Token]]
238240
group_fields.append(exp)
239241
return GroupBy(fields=group_fields), tokens
240242

243+
# Maps every accepted JOIN keyword token (upper-cased) to the join type
# stored on the resulting ``Join`` node.
_JOIN_TYPE_BY_KEYWORD = {
    "JOIN": "INNER",
    "INNER JOIN": "INNER",
    "LEFT JOIN": "LEFT",
    "LEFT OUTER JOIN": "LEFT",
    "RIGHT JOIN": "RIGHT",
    "RIGHT OUTER JOIN": "RIGHT",
    "FULL JOIN": "FULL",
    "FULL OUTER JOIN": "FULL",
    "CROSS JOIN": "CROSS",
    "NATURAL JOIN": "NATURAL",
}


def parse_join(tokens: List[Token]) -> Tuple[Optional[Join], List[Token]]:
    """Parse a single ``<join keyword> table [AS alias] ON expr`` clause.

    Args:
        tokens: Remaining token stream, positioned where a JOIN clause may
            start.

    Returns:
        ``(join, rest)`` where ``join`` is the parsed ``Join`` node and
        ``rest`` the tokens left after its ON expression. If the stream is
        empty or does not start with a recognised JOIN keyword, returns
        ``(None, tokens)`` with the input untouched so the caller can stop
        looking for further joins.

    Raises:
        ValueError: If a JOIN keyword is present but the clause is
            malformed (missing/invalid table name, missing/invalid alias
            after ``AS``, or missing ``ON``).
    """
    # Not a JOIN clause at all: hand the tokens back unchanged.
    if not tokens or tokens[0].type != "keyword":
        return None, tokens
    join_type = _JOIN_TYPE_BY_KEYWORD.get(tokens[0].value.upper())
    if join_type is None:
        return None, tokens
    tokens = tokens[1:]

    # Table name. Explicit checks instead of ``assert`` so validation
    # survives ``python -O`` and truncated input yields a clear error
    # rather than an IndexError.
    if not tokens:
        raise ValueError("Expected table name after JOIN")
    table = tokens[0]
    if table.type != "name":
        raise ValueError(f"Expected table name, got {table}")
    tokens = tokens[1:]

    # Optional ``AS alias``.
    alias = None
    if (
        tokens
        and tokens[0].type == "keyword"
        and tokens[0].value.upper() == "AS"
    ):
        tokens = tokens[1:]
        if not tokens:
            raise ValueError("Expected alias name after AS")
        alias_token = tokens[0]
        if alias_token.type != "name":
            raise ValueError(f"Expected alias name, got {alias_token}")
        alias = alias_token.value
        tokens = tokens[1:]

    # Mandatory ``ON <expression>``.
    # NOTE(review): standard SQL does not take an ON clause for CROSS JOIN
    # or NATURAL JOIN; it is still required here because ``Join.on`` is
    # declared as a non-optional Expression -- confirm the intended
    # behaviour before relaxing this.
    if not (
        tokens
        and tokens[0].type == "keyword"
        and tokens[0].value.upper() == "ON"
    ):
        raise ValueError("Expected ON clause after JOIN")
    on_expression, tokens = parse_expression(tokens[1:])

    return Join(
        table=table.value,
        table_alias=alias,
        join_type=join_type,
        on=on_expression,
    ), tokens
304+
241305

242306
def parse_from(tokens: List[Token]) -> Tuple[Optional[From], List[Token]]:
243307
if len(tokens) == 0:
@@ -261,8 +325,20 @@ def parse_from(tokens: List[Token]) -> Tuple[Optional[From], List[Token]]:
261325
assert alias_token.type == "name", f"Expected alias name, got {alias_token}"
262326
tokens = tokens[1:]
263327
return From(table=table.value, alias=alias_token.value), tokens
264-
else:
328+
329+
join: List[Join] = []
330+
while True:
331+
if len(tokens) == 0:
332+
break
333+
j, tokens = parse_join(tokens)
334+
if j is None:
335+
break
336+
join.append(j)
337+
338+
if len(join) == 0:
265339
return From(table=table.value), tokens
340+
else:
341+
return From(table=table.value, join=join), tokens
266342

267343

268344
def parse_fields(tokens: List[Token]) -> Tuple[List[SelectField], List[Token]]:

abstra_json_sql/tokens.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"WHERE",
2323
"AND",
2424
"AS",
25+
"ON",
2526
"OR",
2627
"NOT",
2728
"IN",
@@ -40,13 +41,18 @@
4041
"ASC",
4142
"DESC",
4243
"INNER JOIN",
44+
"RIGHT OUTER JOIN",
45+
"LEFT OUTER JOIN",
4346
"LEFT JOIN",
4447
"RIGHT JOIN",
4548
"FULL JOIN",
4649
"CROSS JOIN",
4750
"NATURAL JOIN",
4851
"JOIN",
4952
"LIMIT",
53+
"TRUE",
54+
"FALSE",
55+
"NULL"
5056
]
5157

5258

0 commit comments

Comments
 (0)