11from typing import List , Dict , Optional
22from .tables import ITablesSnapshot
3- from .field_name import field_name
3+ from .field_name import field_name , expression_name
44from .ast import (
55 Expression ,
66 StringExpression ,
2525 GreaterThanOrEqualExpression ,
2626 LessThanExpression ,
2727 LessThanOrEqualExpression ,
28+ WildcardExpression ,
2829 Limit ,
2930)
3031
@@ -157,14 +158,32 @@ def apply_expression(expression: Expression, ctx: dict):
157158 )
158159 elif isinstance (expression , FunctionCallExpression ):
159160 if is_aggregate_function (expression .name ):
160- raise NotImplementedError (
161- f"Function call expressions are not implemented: { expression .name } "
162- )
161+ if expression .name .lower () == "count" :
162+ assert len (expression .args ) == 1 , "Count function requires one argument"
163+ if isinstance (expression .args [0 ], WildcardExpression ):
164+ return len (ctx ["__grouped_rows" ])
165+ elif isinstance (expression .args [0 ], NameExpression ):
166+ return len (
167+ [
168+ row
169+ for row in ctx ["__grouped_rows" ]
170+ if expression .args [0 ].name in row
171+ and row [expression .args [0 ].name ] is not None
172+ ]
173+ )
174+ else :
175+ raise ValueError (f"Unknown aggregate function: { expression .name } " )
163176 else :
164177 args = [apply_expression (arg , ctx ) for arg in expression .args ]
165178 if expression .name == "lower" :
179+ assert isinstance (
180+ args [0 ], str
181+ ), "lower function requires a string argument"
166182 return args [0 ].lower ()
167183 elif expression .name == "upper" :
184+ assert isinstance (
185+ args [0 ], str
186+ ), "upper function requires a string argument"
168187 return args [0 ].upper ()
169188 else :
170189 raise ValueError (f"Unknown function: { expression .name } " )
@@ -196,14 +215,27 @@ def apply_order_by(order_by: OrderBy, data: List[dict], ctx: dict):
196215
197216def apply_group_by (group_by : GroupBy , data : List [dict ], ctx : dict ):
198217 groups : Dict [tuple , list ] = {}
199- for row in data :
218+ for idx , row in enumerate ( data ) :
200219 key = tuple (
201220 apply_expression (field , {** ctx , ** row }) for field in group_by .fields
202221 )
203222 if key not in groups :
204223 groups [key ] = []
205224 groups [key ].append (row )
206- return groups
225+ if not groups :
226+ return [
227+ {
228+ "__grouped_rows" : data ,
229+ ** {expression_name (field ): None for field in group_by .fields },
230+ }
231+ ]
232+ return [
233+ {
234+ "__grouped_rows" : rows ,
235+ ** {expression_name (field ): key for field , key in zip (group_by .fields , key )},
236+ }
237+ for key , rows in groups .items ()
238+ ]
207239
208240
209241def apply_limit (limit : Limit , data : List [dict ], ctx : dict ):
@@ -212,11 +244,19 @@ def apply_limit(limit: Limit, data: List[dict], ctx: dict):
212244 return data [start :end ]
213245
214246
247+ def has_aggregation_fields (fields : List [SelectField ]) -> bool :
248+ for field in fields :
249+ if isinstance (field .expression , FunctionCallExpression ):
250+ if is_aggregate_function (field .expression .name ):
251+ return True
252+ return False
253+
254+
215255def apply_select_fields (fields : List [SelectField ], data : List [dict ], ctx : dict ):
216256 return [
217257 {
218258 field_name (field ) or field .expression : apply_expression (
219- field .expression , {** ctx , ** row }
259+ field .expression , {** ctx , "__grouped_rows" : data , ** row }
220260 )
221261 for field in fields
222262 }
@@ -247,12 +287,22 @@ def apply_from(
247287 return data
248288
249289
290+ def has_implicit_aggregation (fields : List [SelectField ]) -> bool :
291+ for field in fields :
292+ if isinstance (field .expression , FunctionCallExpression ):
293+ if is_aggregate_function (field .expression .name ):
294+ return True
295+ return False
296+
297+
250298def apply_select (select : Select , tables : ITablesSnapshot , ctx : dict ):
251299 data = apply_from (select .from_part , tables , ctx )
252300 if select .where_part :
253301 data = apply_where (select .where_part , data , ctx )
254302 if select .group_part :
255303 data = apply_group_by (select .group_part , data , ctx )
304+ elif has_implicit_aggregation (select .field_parts ):
305+ data = apply_group_by (GroupBy (fields = []), data , ctx )
256306 if select .order_part :
257307 data = apply_order_by (select .order_part , data , ctx )
258308 if select .limit_part :
0 commit comments