Skip to content

Commit f309a46

Browse files
committed
having limit offset
1 parent 07dcc62 commit f309a46

7 files changed

Lines changed: 519 additions & 35 deletions

File tree

abstra_json_sql/apply.py

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
FunctionCallExpression,
1616
Command,
1717
PlusExpression,
18+
NullExpression,
19+
NotExpression,
20+
AndExpression,
21+
OrExpression,
22+
IsExpression,
23+
FalseExpression,
24+
TrueExpression,
25+
Wildcard,
1826
OrderBy,
1927
MinusExpression,
2028
MultiplyExpression,
@@ -28,6 +36,7 @@
2836
Wildcard,
2937
Limit,
3038
)
39+
from dataclasses import dataclass
3140

3241

3342
def is_aggregate_function(name: str) -> bool:
@@ -344,6 +353,41 @@ def apply_expression(expression: Expression, ctx: dict):
344353
return args[0].upper()
345354
else:
346355
raise ValueError(f"Unknown function: {expression.name}")
356+
elif isinstance(expression, NullExpression):
357+
return None
358+
elif isinstance(expression, IsExpression):
359+
left_value = apply_expression(expression.left, ctx)
360+
right_value = apply_expression(expression.right, ctx)
361+
if expression.is_not:
362+
return left_value is not right_value
363+
else:
364+
return left_value is right_value
365+
elif isinstance(expression, NotExpression):
366+
value = apply_expression(expression.expression, ctx)
367+
if isinstance(value, bool):
368+
return not value
369+
else:
370+
raise ValueError(f"Not expression should return bool, not {value}")
371+
elif isinstance(expression, AndExpression):
372+
left_value = apply_expression(expression.left, ctx)
373+
right_value = apply_expression(expression.right, ctx)
374+
if isinstance(left_value, bool) and isinstance(right_value, bool):
375+
return left_value and right_value
376+
else:
377+
raise ValueError(
378+
f"Unsupported types for AND: {type(left_value)}, {type(right_value)}"
379+
)
380+
elif isinstance(expression, OrExpression):
381+
left_value = apply_expression(expression.left, ctx)
382+
right_value = apply_expression(expression.right, ctx)
383+
if isinstance(left_value, bool) and isinstance(right_value, bool):
384+
return left_value or right_value
385+
else:
386+
raise ValueError(
387+
f"Unsupported types for OR: {type(left_value)}, {type(right_value)}"
388+
)
389+
elif isinstance(expression, Wildcard):
390+
raise ValueError("Wildcard cannot be used in expressions")
347391
else:
348392
raise ValueError(f"Unsupported expression type: {type(expression)}")
349393

@@ -361,13 +405,62 @@ def apply_where(where: Where, data: List[dict], ctx: dict):
361405
return result
362406

363407

408+
def apply_having(having: Where, data: List[dict], ctx: dict):
    """Keep only the rows for which the HAVING predicate evaluates to True.

    Args:
        having: Clause object whose ``expression`` is evaluated per row.
        data: Rows to filter (typically the output of the grouping step).
        ctx: Base evaluation context; each row's keys are layered on top of
            it, so a row key overrides a context key of the same name.

    Returns:
        A new list holding the rows whose predicate is exactly ``True``.

    Raises:
        ValueError: If the predicate evaluates to anything other than
            ``True`` or ``False`` (``None`` included).
    """
    kept = []
    for row in data:
        verdict = apply_expression(having.expression, {**ctx, **row})
        # Identity checks on purpose: truthy non-bool values (1, "x") must be
        # rejected rather than silently treated as a match.
        if verdict is not True and verdict is not False:
            raise ValueError(f"Having expressions should return bool, not {verdict}")
        if verdict:
            kept.append(row)
    return kept
419+
420+
364421
def apply_order_by(order_by: OrderBy, data: List[dict], ctx: dict):
    """Return a new list with the rows of ``data`` sorted per ``order_by``.

    Each field of ``order_by.fields`` is honoured with its own direction, so
    ``ORDER BY a ASC, b DESC`` sorts ``a`` ascending and breaks ties on ``b``
    descending. ``None`` is treated as smaller than every other value: it
    comes first for ascending fields and last for ``"DESC"`` fields.

    Args:
        order_by: Clause whose ``fields`` each carry an ``expression`` and a
            ``direction``; only the exact value ``"DESC"`` sorts descending.
        data: Rows to sort. The input list is not modified.
        ctx: Base evaluation context; each row's keys are layered on top.

    Returns:
        A new, sorted list of the same row dicts.

    Raises:
        TypeError: If two non-None values of one field cannot be compared
            with ``<`` (for example ``int`` against ``str``).
    """

    def key_for(field):
        # Factory so each pass closes over its own field (no late binding).
        def key(row):
            value = apply_expression(field.expression, {**ctx, **row})
            # (False, None) < (True, anything): None sorts lowest, and two
            # None keys compare equal without ever evaluating None < None.
            return (value is not None, value)

        return key

    sorted_data = list(data)
    # list.sort is stable (also with reverse=True), so sorting by the least
    # significant field first and the most significant field last yields a
    # correct multi-key ordering with independent per-field directions.
    for field in reversed(order_by.fields):
        sorted_data.sort(key=key_for(field), reverse=field.direction == "DESC")
    return sorted_data
371464

372465

373466
def apply_group_by(group_by: GroupBy, data: List[dict], ctx: dict):
@@ -379,7 +472,7 @@ def apply_group_by(group_by: GroupBy, data: List[dict], ctx: dict):
379472
if key not in groups:
380473
groups[key] = []
381474
groups[key].append(row)
382-
if not groups:
475+
if not group_by.fields:
383476
return [
384477
{
385478
"__grouped_rows": data,
@@ -413,7 +506,7 @@ def apply_select_fields(fields: List[SelectField], data: List[dict], ctx: dict):
413506
return [
414507
{
415508
field_name(field) or field.expression: apply_expression(
416-
field.expression, {**ctx, "__grouped_rows": data, **row}
509+
field.expression, {**ctx, **row}
417510
)
418511
for field in fields
419512
}
@@ -460,6 +553,8 @@ def apply_select(select: Select, tables: ITablesSnapshot, ctx: dict):
460553
data = apply_group_by(select.group_part, data, ctx)
461554
elif has_implicit_aggregation(select.field_parts):
462555
data = apply_group_by(GroupBy(fields=[]), data, ctx)
556+
if select.having_part:
557+
data = apply_having(select.having_part, data, ctx)
463558
if select.order_part:
464559
data = apply_order_by(select.order_part, data, ctx)
465560
if select.limit_part:

abstra_json_sql/ast.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,45 @@ class FloatExpression(Expression):
4141
value: float
4242

4343

44+
@dataclass
class AndExpression(Expression):
    """Binary ``AND`` node combining two operand expressions."""

    # Operand on the left-hand side of AND.
    left: Expression
    # Operand on the right-hand side of AND.
    right: Expression
48+
49+
50+
@dataclass
class OrExpression(Expression):
    """Binary ``OR`` node combining two operand expressions."""

    # Operand on the left-hand side of OR.
    left: Expression
    # Operand on the right-hand side of OR.
    right: Expression
54+
55+
56+
@dataclass
class FalseExpression(Expression):
    """Field-less marker node; presumably the SQL ``FALSE`` literal (confirm in the parser)."""
59+
60+
61+
@dataclass
class TrueExpression(Expression):
    """Field-less marker node; presumably the SQL ``TRUE`` literal (confirm in the parser)."""
64+
65+
66+
@dataclass
class NotExpression(Expression):
    """Unary ``NOT`` node wrapping the expression to negate."""

    # Operand whose boolean result is negated.
    expression: Expression
69+
70+
71+
@dataclass
class IsExpression(Expression):
    """``IS`` / ``IS NOT`` comparison node between two expressions."""

    # Operand on the left-hand side of IS.
    left: Expression
    # Operand on the right-hand side of IS (e.g. a NullExpression).
    right: Expression
    # True for the negated form ``IS NOT``; defaults to plain ``IS``.
    is_not: bool = False
76+
77+
78+
@dataclass
class NullExpression(Expression):
    """Field-less node for the ``NULL`` literal."""
81+
82+
4483
@dataclass
4584
class PlusExpression(Expression):
4685
left: Expression
@@ -126,6 +165,11 @@ class Where(Ast):
126165
expression: Expression
127166

128167

168+
@dataclass
class Having(Ast):
    """AST node holding the predicate expression of a ``HAVING`` clause.

    NOTE(review): ``Select.having_part`` is annotated ``Optional[Where]``
    rather than this class - confirm which node the parser actually emits.
    """

    # Predicate evaluated against each row.
    expression: Expression
171+
172+
129173
@dataclass
130174
class SelectField(Ast):
131175
expression: Expression
@@ -159,6 +203,7 @@ class Select(Command):
159203
field_parts: List[Union[SelectField, Wildcard]]
160204
from_part: Optional[From] = None
161205
where_part: Optional[Where] = None
162-
order_part: Optional[OrderBy] = None
163206
group_part: Optional[GroupBy] = None
207+
having_part: Optional[Where] = None
208+
order_part: Optional[OrderBy] = None
164209
limit_part: Optional[Limit] = None

abstra_json_sql/eval_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,169 @@ def test_string_agg(self):
333333
ctx = {}
334334
result = eval_sql(code=code, tables=tables, ctx=ctx)
335335
self.assertEqual(result, [{"string_agg": "a,b,c"}])
336+
337+
def test_limit(self):
    """``LIMIT 1`` yields only the first row of the table."""
    values = ["a", "b", None, "c"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo from bar limit 1",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    self.assertEqual(rows, [{"foo": "a"}])
356+
357+
def test_limit_offset(self):
    """``LIMIT 1 OFFSET 1`` skips one row and yields the second."""
    values = ["a", "b", None, "c"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo from bar limit 1 offset 1",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    self.assertEqual(rows, [{"foo": "b"}])
376+
377+
def test_order_by(self):
    """``ORDER BY`` without a direction sorts ascending, ``None`` first."""
    values = ["c", "b", None, "a"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo from bar order by foo",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    expected = [{"foo": value} for value in (None, "a", "b", "c")]
    self.assertEqual(rows, expected)
398+
399+
def test_order_by_desc(self):
    """``ORDER BY ... DESC`` sorts descending, ``None`` last."""
    values = ["c", "b", None, "a"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo from bar order by foo desc",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    expected = [{"foo": value} for value in ("c", "b", "a", None)]
    self.assertEqual(rows, expected)
420+
421+
def test_order_by_asc(self):
    """``ORDER BY ... ASC`` sorts ascending, ``None`` first."""
    values = ["c", "b", None, "a"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo from bar order by foo asc",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    expected = [{"foo": value} for value in (None, "a", "b", "c")]
    self.assertEqual(rows, expected)
442+
443+
def test_group_by(self):
    """``GROUP BY foo`` counts rows per distinct value, ``None`` included."""
    values = ["a", "b", None, "a"]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(
        code="select foo, count(*) from bar group by foo",
        tables=InMemoryTables(tables=[bar]),
        ctx={},
    )
    # Groups are expected in first-seen order of their key.
    expected = [
        {"foo": key, "count": count}
        for key, count in (("a", 2), ("b", 1), (None, 1))
    ]
    self.assertEqual(rows, expected)
469+
470+
def test_complete(self):
    """Exercise WHERE, GROUP BY, HAVING, ORDER BY and LIMIT/OFFSET together."""
    sql = (
        "select foo, count(*)\n"
        "from bar as baz\n"
        "where foo is not null\n"
        "group by foo\n"
        "having foo <> 2\n"
        "order by foo\n"
        "limit 1 offset 1"
    )
    values = [1, 2, 3, 2, None, 3, 1]
    bar = Table(
        name="bar",
        columns=[Column(name="foo", type="text")],
        data=[{"foo": value} for value in values],
    )
    rows = eval_sql(code=sql, tables=InMemoryTables(tables=[bar]), ctx={})
    # Groups 1 and 3 survive HAVING; OFFSET 1 skips group 1.
    self.assertEqual(rows, [{"foo": 3, "count": 2}])

0 commit comments

Comments
 (0)