Skip to content

Commit 4dcbd11

Browse files
committed
join
1 parent dfea28f commit 4dcbd11

4 files changed

Lines changed: 94 additions & 33 deletions

File tree

abstra_json_sql/apply.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,11 @@ def apply_select_fields(fields: List[SelectField], data: List[dict], ctx: dict):
511511
result_row = {}
512512
for field in fields:
513513
if isinstance(field.expression, Wildcard):
514-
result_row.update(row)
514+
for key, value in row.items():
515+
assert isinstance(key, str), "Key should be a string"
516+
parts = key.split(".")
517+
last_part = parts[-1]
518+
result_row[last_part] = value
515519
elif isinstance(field.expression, Expression):
516520
result_row[field_name(field)] = apply_expression(
517521
field.expression, {**ctx, **row}
@@ -522,6 +526,11 @@ def apply_select_fields(fields: List[SelectField], data: List[dict], ctx: dict):
522526
return result
523527

524528

529+
def add_scope_to_keys(prefix: str, data: dict) -> dict:
530+
change = {f"{prefix}.{key}": value for key, value in data.items()}
531+
return {**data, **change}
532+
533+
525534
def apply_from(
526535
from_part: Optional[From], tables: ITablesSnapshot, ctx: dict
527536
) -> List[dict]:
@@ -537,11 +546,23 @@ def apply_from(
537546
if not join_table:
538547
raise ValueError(f"Table {join.table} not found")
539548
data = [
540-
{**row, **join_row}
549+
{
550+
**add_scope_to_keys(table.name, row),
551+
**add_scope_to_keys(join_table.name, join_row),
552+
}
541553
for row in data
542554
for join_row in join_table.data
543-
if apply_expression(join.on, {**ctx, **row, **join_row})
555+
if apply_expression(
556+
join.on,
557+
{
558+
**ctx,
559+
**add_scope_to_keys(table.name, row),
560+
**add_scope_to_keys(join_table.name, join_row),
561+
},
562+
)
544563
]
564+
else:
565+
data = [{**add_scope_to_keys(table.name, row)} for row in data]
545566
return data
546567

547568

abstra_json_sql/eval_test.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,72 @@ def test_group_by(self):
467467
],
468468
)
469469

470+
def test_select_wildcard(self):
471+
code = "select * from bar"
472+
tables = InMemoryTables(
473+
tables=[
474+
Table(
475+
name="bar",
476+
columns=[Column(name="foo", type="text")],
477+
data=[
478+
{"foo": "a"},
479+
{"foo": "b"},
480+
{"foo": None},
481+
{"foo": "c"},
482+
],
483+
)
484+
],
485+
)
486+
ctx = {}
487+
result = eval_sql(code=code, tables=tables, ctx=ctx)
488+
self.assertEqual(
489+
result,
490+
[
491+
{"foo": "a"},
492+
{"foo": "b"},
493+
{"foo": None},
494+
{"foo": "c"},
495+
],
496+
)
497+
498+
def test_join(self):
499+
code = "select a.foo, b.bar from a join b on a.id = b.a_id"
500+
tables = InMemoryTables(
501+
tables=[
502+
Table(
503+
name="a",
504+
columns=[
505+
Column(name="id", type="int"),
506+
Column(name="foo", type="text"),
507+
],
508+
data=[
509+
{"id": 1, "foo": "a1"},
510+
{"id": 2, "foo": "a2"},
511+
],
512+
),
513+
Table(
514+
name="b",
515+
columns=[
516+
Column(name="a_id", type="int"),
517+
Column(name="bar", type="text"),
518+
],
519+
data=[
520+
{"a_id": 1, "bar": "b1"},
521+
{"a_id": 2, "bar": "b2"},
522+
],
523+
),
524+
],
525+
)
526+
ctx = {}
527+
result = eval_sql(code=code, tables=tables, ctx=ctx)
528+
self.assertEqual(
529+
result,
530+
[
531+
{"foo": "a1", "bar": "b1"},
532+
{"foo": "a2", "bar": "b2"},
533+
],
534+
)
535+
470536
def test_complete(self):
471537
code = "\n".join(
472538
[
@@ -499,31 +565,3 @@ def test_complete(self):
499565
ctx = {}
500566
result = eval_sql(code=code, tables=tables, ctx=ctx)
501567
self.assertEqual(result, [{"foo": 3, "count": 2}])
502-
503-
def test_select_wildcard(self):
504-
code = "select * from bar"
505-
tables = InMemoryTables(
506-
tables=[
507-
Table(
508-
name="bar",
509-
columns=[Column(name="foo", type="text")],
510-
data=[
511-
{"foo": "a"},
512-
{"foo": "b"},
513-
{"foo": None},
514-
{"foo": "c"},
515-
],
516-
)
517-
],
518-
)
519-
ctx = {}
520-
result = eval_sql(code=code, tables=tables, ctx=ctx)
521-
self.assertEqual(
522-
result,
523-
[
524-
{"foo": "a"},
525-
{"foo": "b"},
526-
{"foo": None},
527-
{"foo": "c"},
528-
],
529-
)

abstra_json_sql/field_name.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def expression_name(expression: Expression):
1616
Get the field name from an Expression object.
1717
"""
1818
if isinstance(expression, NameExpression):
19-
return expression.name
19+
name_parts = expression.name.split(".")
20+
last_part = name_parts[-1]
21+
return last_part
2022
elif isinstance(expression, FunctionCallExpression):
2123
return expression.name
2224
else:

abstra_json_sql/lexer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def extract_space(code: str) -> str:
3232

3333

3434
def start_with_name(code: str):
35-
return code[0].isalnum() or code[0] == "_"
35+
return code[0].isalnum() or code[0] == "_" or code[0] == "."
3636

3737

3838
def extract_name(code: str):

0 commit comments

Comments
 (0)