Skip to content

Commit 61b1294

Browse files
committed
Adding aggregation functions
1 parent 28f92f9 commit 61b1294

7 files changed

Lines changed: 426 additions & 14 deletions

File tree

abstra_json_sql/apply.py

Lines changed: 162 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,27 @@
2525
GreaterThanOrEqualExpression,
2626
LessThanExpression,
2727
LessThanOrEqualExpression,
28-
WildcardExpression,
28+
Wildcard,
2929
Limit,
3030
)
3131

3232

3333
def is_aggregate_function(name: str) -> bool:
    """Return True if *name* is one of the supported aggregate function names.

    The comparison is case-insensitive: the name is lower-cased before the
    lookup, so ``"SUM"`` and ``"sum"`` are treated the same.
    """
    # A set literal of constants gives O(1) membership instead of a list scan.
    return name.lower() in {
        "sum",
        "avg",
        "count",
        "min",
        "max",
        "every",
        "bool_or",
        "bool_and",
        "bit_or",
        "bit_and",
        "array_agg",
        "string_agg",
    }
3649

3750

3851
def apply_expression(expression: Expression, ctx: dict):
@@ -160,17 +173,161 @@ def apply_expression(expression: Expression, ctx: dict):
160173
if is_aggregate_function(expression.name):
161174
if expression.name.lower() == "count":
162175
assert len(expression.args) == 1, "Count function requires one argument"
163-
if isinstance(expression.args[0], WildcardExpression):
176+
if isinstance(expression.args[0], Wildcard):
164177
return len(ctx["__grouped_rows"])
165178
elif isinstance(expression.args[0], NameExpression):
166179
return len(
167180
[
168181
row
169182
for row in ctx["__grouped_rows"]
170-
if expression.args[0].name in row
171-
and row[expression.args[0].name] is not None
183+
if apply_expression(expression.args[0], {**ctx, **row})
184+
is not None
185+
and apply_expression(expression.args[0], {**ctx, **row})
186+
is not None
172187
]
173188
)
189+
elif expression.name.lower() == "sum":
190+
assert len(expression.args) == 1, "Sum function requires one argument"
191+
return sum(
192+
apply_expression(expression.args[0], {**ctx, **row})
193+
for row in ctx["__grouped_rows"]
194+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
195+
and isinstance(
196+
apply_expression(expression.args[0], {**ctx, **row}),
197+
(int, float),
198+
)
199+
)
200+
elif expression.name.lower() == "avg":
201+
assert len(expression.args) == 1, "Avg function requires one argument"
202+
values = [
203+
apply_expression(expression.args[0], {**ctx, **row})
204+
for row in ctx["__grouped_rows"]
205+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
206+
and isinstance(
207+
apply_expression(expression.args[0], {**ctx, **row}),
208+
(int, float),
209+
)
210+
]
211+
if not values:
212+
return None
213+
return sum(values) / len(values)
214+
elif expression.name.lower() == "min":
215+
assert len(expression.args) == 1, "Min function requires one argument"
216+
values = [
217+
apply_expression(expression.args[0], {**ctx, **row})
218+
for row in ctx["__grouped_rows"]
219+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
220+
and isinstance(
221+
apply_expression(expression.args[0], {**ctx, **row}),
222+
(int, float),
223+
)
224+
]
225+
if not values:
226+
return None
227+
return min(values)
228+
elif expression.name.lower() == "max":
229+
assert len(expression.args) == 1, "Max function requires one argument"
230+
values = [
231+
apply_expression(expression.args[0], {**ctx, **row})
232+
for row in ctx["__grouped_rows"]
233+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
234+
and isinstance(
235+
apply_expression(expression.args[0], {**ctx, **row}),
236+
(int, float),
237+
)
238+
]
239+
if not values:
240+
return None
241+
return max(values)
242+
elif expression.name.lower() == "every":
243+
assert len(expression.args) == 1, "Every function requires one argument"
244+
return all(
245+
apply_expression(expression.args[0], {**ctx, **row})
246+
for row in ctx["__grouped_rows"]
247+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
248+
and isinstance(
249+
apply_expression(expression.args[0], {**ctx, **row}), bool
250+
)
251+
)
252+
elif expression.name.lower() == "bool_or":
253+
assert (
254+
len(expression.args) == 1
255+
), "Bool_or function requires one argument"
256+
return any(
257+
apply_expression(expression.args[0], {**ctx, **row})
258+
for row in ctx["__grouped_rows"]
259+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
260+
and isinstance(
261+
apply_expression(expression.args[0], {**ctx, **row}), bool
262+
)
263+
)
264+
elif expression.name.lower() == "bool_and":
265+
assert (
266+
len(expression.args) == 1
267+
), "Bool_and function requires one argument"
268+
return all(
269+
apply_expression(expression.args[0], {**ctx, **row})
270+
for row in ctx["__grouped_rows"]
271+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
272+
and isinstance(
273+
apply_expression(expression.args[0], {**ctx, **row}), bool
274+
)
275+
)
276+
elif expression.name.lower() == "bit_or":
277+
assert (
278+
len(expression.args) == 1
279+
), "Bit_or function requires one argument"
280+
not_null_rows = [
281+
row
282+
for row in ctx["__grouped_rows"]
283+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
284+
and apply_expression(expression.args[0], {**ctx, **row}) is not None
285+
]
286+
if len(not_null_rows) == 0:
287+
return None
288+
result_bits = apply_expression(
289+
expression.args[0], {**ctx, **not_null_rows[0]}
290+
)
291+
for row in not_null_rows[1:]:
292+
result_bits |= apply_expression(expression.args[0], {**ctx, **row})
293+
return result_bits
294+
elif expression.name.lower() == "bit_and":
295+
assert (
296+
len(expression.args) == 1
297+
), "Bit_and function requires one argument"
298+
not_null_rows = [
299+
row
300+
for row in ctx["__grouped_rows"]
301+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
302+
and apply_expression(expression.args[0], {**ctx, **row}) is not None
303+
]
304+
if len(not_null_rows) == 0:
305+
return None
306+
result_bits = apply_expression(
307+
expression.args[0], {**ctx, **not_null_rows[0]}
308+
)
309+
for row in not_null_rows[1:]:
310+
result_bits &= apply_expression(expression.args[0], {**ctx, **row})
311+
return result_bits
312+
elif expression.name.lower() == "array_agg":
313+
assert (
314+
len(expression.args) == 1
315+
), "Array_agg function requires one argument"
316+
return [
317+
apply_expression(expression.args[0], {**ctx, **row})
318+
for row in ctx["__grouped_rows"]
319+
]
320+
elif expression.name.lower() == "string_agg":
321+
assert (
322+
len(expression.args) == 2
323+
), "String_agg function requires two arguments"
324+
separator = expression.args[1].value
325+
return separator.join(
326+
str(apply_expression(expression.args[0], {**ctx, **row}))
327+
for row in ctx["__grouped_rows"]
328+
if apply_expression(expression.args[0], {**ctx, **row}) is not None
329+
and apply_expression(expression.args[0], {**ctx, **row}) is not None
330+
)
174331
else:
175332
raise ValueError(f"Unknown aggregate function: {expression.name}")
176333
else:

abstra_json_sql/ast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class NotEqualExpression(Expression):
102102

103103

104104
@dataclass
class Wildcard(Ast):
    """Field-less AST marker node for a wildcard (presumably SQL ``*`` — confirm against the parser).

    Derives from ``Ast`` directly rather than from ``Expression``.
    """
107107

108108

@@ -156,7 +156,7 @@ class Limit(Ast):
156156

157157
@dataclass
158158
class Select(Command):
159-
field_parts: List[Union[SelectField, WildcardExpression]]
159+
field_parts: List[Union[SelectField, Wildcard]]
160160
from_part: Optional[From] = None
161161
where_part: Optional[Where] = None
162162
order_part: Optional[OrderBy] = None

0 commit comments

Comments
 (0)