Skip to content

Commit 6661f14

Browse files
committed
functions
1 parent f65cc2c commit 6661f14

5 files changed

Lines changed: 78 additions & 16 deletions

File tree

abstra_json_sql/apply.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
)
3030

3131

32+
def is_aggregate_function(name: str) -> bool:
33+
# Placeholder for aggregate function check
34+
return name.lower() in ["sum", "avg", "count", "min", "max"]
35+
36+
3237
def apply_expression(expression: Expression, ctx: dict):
3338
if isinstance(expression, StringExpression):
3439
return expression.value
@@ -151,9 +156,18 @@ def apply_expression(expression: Expression, ctx: dict):
151156
f"Unsupported types for less than or equal: {type(left_value)}, {type(right_value)}"
152157
)
153158
elif isinstance(expression, FunctionCallExpression):
154-
raise NotImplementedError(
155-
f"Function call expressions are not implemented: {expression.name}"
156-
)
159+
if is_aggregate_function(expression.name):
160+
raise NotImplementedError(
161+
f"Function call expressions are not implemented: {expression.name}"
162+
)
163+
else:
164+
args = [apply_expression(arg, ctx) for arg in expression.args]
165+
if expression.name == "lower":
166+
return args[0].lower()
167+
elif expression.name == "upper":
168+
return args[0].upper()
169+
else:
170+
raise ValueError(f"Unknown function: {expression.name}")
157171
else:
158172
raise ValueError(f"Unsupported expression type: {type(expression)}")
159173

abstra_json_sql/eval_test.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from unittest import TestCase
22
from .eval import eval_sql
3-
from .tables import InMemoryTables, Table
3+
from .tables import InMemoryTables, Table, Column
44

55

66
class TestEvalSQL(TestCase):
7-
def test_eval_sql(self):
7+
def test_sql(self):
88
code = "select 1+1"
99
tables = InMemoryTables(
1010
tables=[],
@@ -13,7 +13,7 @@ def test_eval_sql(self):
1313
result = eval_sql(code=code, tables=tables, ctx=ctx)
1414
self.assertEqual(result, [{"?column?": 2}])
1515

16-
def test_eval_select_alias(self):
16+
def test_select_alias(self):
1717
code = "select 1+1 as a"
1818
tables = InMemoryTables(
1919
tables=[],
@@ -22,21 +22,54 @@ def test_eval_select_alias(self):
2222
result = eval_sql(code=code, tables=tables, ctx=ctx)
2323
self.assertEqual(result, [{"a": 2}])
2424

25-
def test_eval_aggregate_sum(self):
26-
code = "select sum(foo) from bar"
25+
def test_lower(self):
26+
code = "select lower(foo) from bar"
2727
tables = InMemoryTables(
2828
tables=[
2929
Table(
3030
name="bar",
31-
columns=["foo"],
32-
rows=[
33-
{"foo": 1},
34-
{"foo": 2},
35-
{"foo": 3},
31+
columns=[Column(name="foo", type="text")],
32+
data=[
33+
{"foo": "AAA"},
34+
{"foo": "BBB"},
35+
{"foo": "CCC"},
3636
],
3737
)
3838
],
3939
)
4040
ctx = {}
4141
result = eval_sql(code=code, tables=tables, ctx=ctx)
42-
self.assertEqual(result, [{"a": 6}])
42+
self.assertEqual(
43+
result,
44+
[
45+
{"lower": "aaa"},
46+
{"lower": "bbb"},
47+
{"lower": "ccc"},
48+
],
49+
)
50+
51+
def test_upper(self):
52+
code = "select upper(foo) from bar"
53+
tables = InMemoryTables(
54+
tables=[
55+
Table(
56+
name="bar",
57+
columns=[Column(name="foo", type="text")],
58+
data=[
59+
{"foo": "aaa"},
60+
{"foo": "bbb"},
61+
{"foo": "ccc"},
62+
],
63+
)
64+
],
65+
)
66+
ctx = {}
67+
result = eval_sql(code=code, tables=tables, ctx=ctx)
68+
self.assertEqual(
69+
result,
70+
[
71+
{"upper": "AAA"},
72+
{"upper": "BBB"},
73+
{"upper": "CCC"},
74+
],
75+
)

abstra_json_sql/field_name.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .ast import SelectField, NameExpression
1+
from .ast import SelectField, NameExpression, FunctionCallExpression
22

33

44
def field_name(field: SelectField) -> str:
@@ -9,5 +9,7 @@ def field_name(field: SelectField) -> str:
99
return field.alias
1010
elif isinstance(field.expression, NameExpression):
1111
return field.expression.name
12+
elif isinstance(field.expression, FunctionCallExpression):
13+
return field.expression.name
1214
else:
1315
return "?column?"

abstra_json_sql/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ def parse_expression(tokens: List[Token]) -> Tuple[Expression, List[Token]]:
7979
args = []
8080
while True:
8181
param_expression, tokens = parse_expression(tokens)
82+
args.append(param_expression)
8283
if tokens and tokens[0].type == "comma":
8384
tokens = tokens[1:]
8485
elif tokens and tokens[0].type == "paren_right":
8586
tokens = tokens[1:]
8687
break
8788
else:
8889
raise ValueError("Expected comma or closing parenthesis")
89-
args.append(param_expression)
9090

9191
stack.append(FunctionCallExpression(name=name_value, args=args))
9292
else:

abstra_json_sql/parser_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
From,
77
SelectWildcard,
88
PlusExpression,
9+
FunctionCallExpression,
910
SelectField,
1011
IntExpression,
1112
NameExpression,
@@ -136,6 +137,18 @@ def test_equal_expression(self):
136137
)
137138
self.assertEqual(tokens, [])
138139

140+
def test_function_call_expression(self):
141+
tokens = scan("SUM(foo)")
142+
ast, tokens = parse_expression(tokens)
143+
self.assertEqual(
144+
ast,
145+
FunctionCallExpression(
146+
name="SUM",
147+
args=[NameExpression(name="foo")],
148+
),
149+
)
150+
self.assertEqual(tokens, [])
151+
139152

140153
class FromTest(TestCase):
141154
def test_simple(self):

0 commit comments

Comments
 (0)