Skip to content

Commit 28f92f9

Browse files
committed
count
1 parent 6661f14 commit 28f92f9

7 files changed

Lines changed: 116 additions & 33 deletions

File tree

abstra_json_sql/apply.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List, Dict, Optional
22
from .tables import ITablesSnapshot
3-
from .field_name import field_name
3+
from .field_name import field_name, expression_name
44
from .ast import (
55
Expression,
66
StringExpression,
@@ -25,6 +25,7 @@
2525
GreaterThanOrEqualExpression,
2626
LessThanExpression,
2727
LessThanOrEqualExpression,
28+
WildcardExpression,
2829
Limit,
2930
)
3031

@@ -157,14 +158,32 @@ def apply_expression(expression: Expression, ctx: dict):
157158
)
158159
elif isinstance(expression, FunctionCallExpression):
159160
if is_aggregate_function(expression.name):
160-
raise NotImplementedError(
161-
f"Function call expressions are not implemented: {expression.name}"
162-
)
161+
if expression.name.lower() == "count":
162+
assert len(expression.args) == 1, "Count function requires one argument"
163+
if isinstance(expression.args[0], WildcardExpression):
164+
return len(ctx["__grouped_rows"])
165+
elif isinstance(expression.args[0], NameExpression):
166+
return len(
167+
[
168+
row
169+
for row in ctx["__grouped_rows"]
170+
if expression.args[0].name in row
171+
and row[expression.args[0].name] is not None
172+
]
173+
)
174+
else:
175+
raise ValueError(f"Unknown aggregate function: {expression.name}")
163176
else:
164177
args = [apply_expression(arg, ctx) for arg in expression.args]
165178
if expression.name == "lower":
179+
assert isinstance(
180+
args[0], str
181+
), "lower function requires a string argument"
166182
return args[0].lower()
167183
elif expression.name == "upper":
184+
assert isinstance(
185+
args[0], str
186+
), "upper function requires a string argument"
168187
return args[0].upper()
169188
else:
170189
raise ValueError(f"Unknown function: {expression.name}")
@@ -196,14 +215,27 @@ def apply_order_by(order_by: OrderBy, data: List[dict], ctx: dict):
196215

197216
def apply_group_by(group_by: GroupBy, data: List[dict], ctx: dict):
198217
groups: Dict[tuple, list] = {}
199-
for row in data:
218+
for idx, row in enumerate(data):
200219
key = tuple(
201220
apply_expression(field, {**ctx, **row}) for field in group_by.fields
202221
)
203222
if key not in groups:
204223
groups[key] = []
205224
groups[key].append(row)
206-
return groups
225+
if not groups:
226+
return [
227+
{
228+
"__grouped_rows": data,
229+
**{expression_name(field): None for field in group_by.fields},
230+
}
231+
]
232+
return [
233+
{
234+
"__grouped_rows": rows,
235+
**{expression_name(field): key for field, key in zip(group_by.fields, key)},
236+
}
237+
for key, rows in groups.items()
238+
]
207239

208240

209241
def apply_limit(limit: Limit, data: List[dict], ctx: dict):
@@ -212,11 +244,19 @@ def apply_limit(limit: Limit, data: List[dict], ctx: dict):
212244
return data[start:end]
213245

214246

247+
def has_aggregation_fields(fields: List[SelectField]) -> bool:
248+
for field in fields:
249+
if isinstance(field.expression, FunctionCallExpression):
250+
if is_aggregate_function(field.expression.name):
251+
return True
252+
return False
253+
254+
215255
def apply_select_fields(fields: List[SelectField], data: List[dict], ctx: dict):
216256
return [
217257
{
218258
field_name(field) or field.expression: apply_expression(
219-
field.expression, {**ctx, **row}
259+
field.expression, {**ctx, "__grouped_rows": data, **row}
220260
)
221261
for field in fields
222262
}
@@ -247,12 +287,22 @@ def apply_from(
247287
return data
248288

249289

290+
def has_implicit_aggregation(fields: List[SelectField]) -> bool:
291+
for field in fields:
292+
if isinstance(field.expression, FunctionCallExpression):
293+
if is_aggregate_function(field.expression.name):
294+
return True
295+
return False
296+
297+
250298
def apply_select(select: Select, tables: ITablesSnapshot, ctx: dict):
251299
data = apply_from(select.from_part, tables, ctx)
252300
if select.where_part:
253301
data = apply_where(select.where_part, data, ctx)
254302
if select.group_part:
255303
data = apply_group_by(select.group_part, data, ctx)
304+
elif has_implicit_aggregation(select.field_parts):
305+
data = apply_group_by(GroupBy(fields=[]), data, ctx)
256306
if select.order_part:
257307
data = apply_order_by(select.order_part, data, ctx)
258308
if select.limit_part:

abstra_json_sql/apply_test.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,23 @@ def test_group_by(self):
202202

203203
self.assertEqual(
204204
result,
205-
{
206-
("foo",): [
207-
{"name": "Alice", "team": "foo"},
208-
{"name": "Charlie", "team": "foo"},
209-
{"name": "Eve", "team": "foo"},
210-
],
211-
("bar",): [
212-
{"name": "Bob", "team": "bar"},
213-
{"name": "David", "team": "bar"},
214-
],
215-
},
205+
[
206+
{
207+
"team": "foo",
208+
"__grouped_rows": [
209+
{"name": "Alice", "team": "foo"},
210+
{"name": "Charlie", "team": "foo"},
211+
{"name": "Eve", "team": "foo"},
212+
],
213+
},
214+
{
215+
"team": "bar",
216+
"__grouped_rows": [
217+
{"name": "Bob", "team": "bar"},
218+
{"name": "David", "team": "bar"},
219+
],
220+
},
221+
],
216222
)
217223

218224

abstra_json_sql/ast.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ class NotEqualExpression(Expression):
101101
right: Expression
102102

103103

104+
@dataclass
105+
class WildcardExpression(Expression):
106+
pass
107+
108+
104109
@dataclass
105110
class Join(Ast):
106111
table: str
@@ -138,11 +143,6 @@ class OrderBy(Ast):
138143
fields: List[OrderField]
139144

140145

141-
@dataclass
142-
class SelectWildcard(Ast):
143-
pass
144-
145-
146146
@dataclass
147147
class GroupBy(Ast):
148148
fields: List[Expression]
@@ -156,7 +156,7 @@ class Limit(Ast):
156156

157157
@dataclass
158158
class Select(Command):
159-
field_parts: List[Union[SelectField, SelectWildcard]]
159+
field_parts: List[Union[SelectField, WildcardExpression]]
160160
from_part: Optional[From] = None
161161
where_part: Optional[Where] = None
162162
order_part: Optional[OrderBy] = None

abstra_json_sql/eval_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,22 @@ def test_upper(self):
7373
{"upper": "CCC"},
7474
],
7575
)
76+
77+
def test_count(self):
78+
code = "select count(foo) from bar"
79+
tables = InMemoryTables(
80+
tables=[
81+
Table(
82+
name="bar",
83+
columns=[Column(name="foo", type="text")],
84+
data=[
85+
{"foo": "aaa"},
86+
{"foo": "bbb"},
87+
{"foo": "ccc"},
88+
],
89+
)
90+
],
91+
)
92+
ctx = {}
93+
result = eval_sql(code=code, tables=tables, ctx=ctx)
94+
self.assertEqual(result, [{"count": 3}])

abstra_json_sql/field_name.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .ast import SelectField, NameExpression, FunctionCallExpression
1+
from .ast import SelectField, NameExpression, FunctionCallExpression, Expression
22

33

44
def field_name(field: SelectField) -> str:
@@ -7,9 +7,17 @@ def field_name(field: SelectField) -> str:
77
"""
88
if field.alias:
99
return field.alias
10-
elif isinstance(field.expression, NameExpression):
11-
return field.expression.name
12-
elif isinstance(field.expression, FunctionCallExpression):
13-
return field.expression.name
10+
else:
11+
return expression_name(field.expression)
12+
13+
14+
def expression_name(expression: Expression):
15+
"""
16+
Get the field name from an Expression object.
17+
"""
18+
if isinstance(expression, NameExpression):
19+
return expression.name
20+
elif isinstance(expression, FunctionCallExpression):
21+
return expression.name
1422
else:
1523
return "?column?"

abstra_json_sql/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
SelectField,
77
Limit,
88
IntExpression,
9-
SelectWildcard,
9+
WildcardExpression,
1010
FunctionCallExpression,
1111
From,
1212
NameExpression,
@@ -184,7 +184,7 @@ def parse_fields(tokens: List[Token]) -> Tuple[List[SelectField], List[Token]]:
184184
if tokens[0].type == "keyword" and tokens[0].value.upper() == "FROM":
185185
break
186186
if tokens[0].type == "wildcard":
187-
fields.append(SelectWildcard())
187+
fields.append(WildcardExpression())
188188
tokens = tokens[1:]
189189
else:
190190
exp, tokens = parse_expression(tokens)

abstra_json_sql/parser_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .ast import (
55
Select,
66
From,
7-
SelectWildcard,
7+
WildcardExpression,
88
PlusExpression,
99
FunctionCallExpression,
1010
SelectField,
@@ -37,7 +37,7 @@ def test_select_wildcard(self):
3737
self.assertEqual(
3838
ast,
3939
Select(
40-
field_parts=[SelectWildcard()],
40+
field_parts=[WildcardExpression()],
4141
from_part=From(
4242
table="users",
4343
),

0 commit comments

Comments
 (0)