Skip to content

Commit 0972a3e

Browse files
committed
ENH: Lock in aggregation functions
1 parent fe2e55e commit 0972a3e

4 files changed

Lines changed: 34 additions & 43 deletions

File tree

dataframe_sql/grammar/sql.grammar

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ JOIN_TYPE.5: "INNER"i | "CROSS"i | /FULL\sOUTER/i | /LEFT\sOUTER/i | /RIGHT\sOU
4242
| "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression
4343
| "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type
4444
| "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast
45-
| aggregation "(" expression_math ")" -> sql_aggregation
45+
| AGGREGATION "(" expression_math ")" -> sql_aggregation
4646
| "RANK"i "(" ")" rank_form -> rank_expression
4747
| "DENSE_RANK"i "(" ")" rank_form -> dense_rank_expression
4848

@@ -54,9 +54,12 @@ order: expression_math ["ASC"i] -> order_asc
5454
| expression_math "DESC"i -> order_desc
5555

5656
column_name: [NAME "."] NAME
57-
?expression_product: expression
58-
| expression_product "*" expression -> expression_mul
59-
| expression_product "/" expression -> expression_div
57+
?expression_product: expression_parens
58+
| expression_product "*" expression_parens -> expression_mul
59+
| expression_product "/" expression_parens -> expression_div
60+
61+
?expression_parens: expression
62+
| "(" expression_parens "*" expression ")"
6063

6164
?expression: [NAME "."] (NAME | STAR) -> column_name
6265
| literal
@@ -79,7 +82,7 @@ TYPENAME: "object"i
7982
| "datetime64"i
8083
| "timestamp"i
8184
| "category"i
82-
?aggregation: NAME -> aggregation_name
85+
AGGREGATION.8: "sum"i | "avg"i | "min"i | "max"i
8386
alias: NAME -> alias_string
8487
_window_name: NAME
8588
limit_count: integer -> limit_count

dataframe_sql/parsing/sql_parser.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,6 @@ def where_expr(self, truth_value_dataframe):
554554
"""
555555
return Token("where_expr", truth_value_dataframe[0])
556556

557-
def aggregation_name(self, aggregation_name):
558-
"""
559-
Returns the function name tree
560-
:param aggregation_name:
561-
:return:
562-
"""
563-
return aggregation_name[0].value
564-
565557
def alias_string(self, name: List[str]):
566558
"""
567559
Returns an alias token_or_tree with the name extracted
@@ -918,21 +910,14 @@ def aggregate(self, function_name_list_form):
918910
"""
919911
return "".join(function_name_list_form)
920912

921-
def aggregation_name(self, tokens):
922-
"""
923-
Extracts function name from token_or_tree
924-
:param tokens:
925-
:return:
926-
"""
927-
return tokens[0].value
928-
929-
def sql_aggregation(self, aggregation_expr):
913+
def sql_aggregation(self, aggregation_expr: list):
930914
"""
931915
Handles presence of aggregation in an sql_object
932916
:param aggregation_expr: Function sql_object
933917
:return:
934918
"""
935-
aggregate_name = aggregation_expr[0]
919+
aggregate_token: Token = aggregation_expr[0]
920+
aggregate_name: str = aggregate_token.value
936921
column = aggregation_expr[1]
937922
table = self.dataframe_map[column.table]
938923
column_true_name = self.column_name_map[column.table][column.name]

dataframe_sql/tests/pandas_sql_functionality_test.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,29 +1410,29 @@ def test_order_of_operations_no_parens():
14101410
tm.assert_frame_equal(pandas_frame, my_frame)
14111411

14121412

1413-
# def test_order_of_operations_with_parens():
1414-
# """
1415-
# Test math parentheses
1416-
# :return:
1417-
# """
1418-
#
1419-
# my_frame = query(
1420-
# "select 20 * (avocado_id + 3) / (20 + avocado_id) as my_math from " "avocado"
1421-
# )
1422-
#
1423-
# pandas_frame = AVOCADO.copy()[["avocado_id"]]
1424-
# pandas_frame["my_math"] = (
1425-
# 20 * (pandas_frame["avocado_id"] + 3) / (20 + pandas_frame["avocado_id"])
1426-
# )
1427-
#
1428-
# pandas_frame = pandas_frame.drop(columns=["avocado_id"])
1429-
#
1430-
# tm.assert_frame_equal(pandas_frame, my_frame)
1413+
def test_order_of_operations_with_parens():
1414+
"""
1415+
Test math parentheses
1416+
:return:
1417+
"""
1418+
1419+
my_frame = query(
1420+
"select 20 * (avocado_id + 3) / (20 + avocado_id) as my_math from " "avocado"
1421+
)
1422+
1423+
pandas_frame = AVOCADO.copy()[["avocado_id"]]
1424+
pandas_frame["my_math"] = (
1425+
20 * (pandas_frame["avocado_id"] + 3) / (20 + pandas_frame["avocado_id"])
1426+
)
1427+
1428+
pandas_frame = pandas_frame.drop(columns=["avocado_id"])
1429+
1430+
tm.assert_frame_equal(pandas_frame, my_frame)
14311431

14321432

14331433
if __name__ == "__main__":
14341434
register_env_tables()
14351435

1436-
# test_order_of_operations_with_parens()
1436+
test_sum()
14371437

14381438
remove_env_tables()

dataframe_sql/tests/sql_execution_plan_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ def test_having_one_condition():
462462
"select min(temp) from forest_fires having min(temp) > 2",
463463
show_execution_plan=True,
464464
)
465+
466+
print(plan)
467+
465468
assert (
466469
plan == "FOREST_FIRES.loc[:, ['temp']].assign(__=1).groupby(['__'])"
467470
".agg(**{'_col0': ('temp', 'min')}).reset_index(drop=True)"
@@ -1087,6 +1090,6 @@ def test_timestamps():
10871090
if __name__ == "__main__":
10881091
register_env_tables()
10891092

1090-
test_using_math()
1093+
test_having_one_condition()
10911094

10921095
remove_env_tables()

0 commit comments

Comments
 (0)