Skip to content

Commit 8816edc

Browse files
committed
Fix macro func variable extraction & add tests
1 parent e810b39 commit 8816edc

2 files changed

Lines changed: 33 additions & 8 deletions

File tree

sqlmesh/core/model/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,22 @@ def make_python_env(
160160
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
161161
references = set()
162162

163-
for n in macro_func.walk():
164-
if n is macro_func:
165-
continue
163+
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
164+
# they will be handled in a separate call of _extract_macro_func_variable_references.
165+
def _prune_nested_macro_func(expression: exp.Expression) -> bool:
166+
return (
167+
type(n) is d.MacroFunc
168+
and n is not macro_func
169+
and n.this.name.lower() not in (c.VAR, c.BLUEPRINT_VAR)
170+
)
166171

167-
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168-
# they will be handled in a separate call of _extract_macro_func_variable_references.
169-
if isinstance(n, d.MacroFunc):
172+
for n in macro_func.walk(prune=_prune_nested_macro_func):
173+
if type(n) is d.MacroFunc:
170174
this = n.this
171-
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and this.expressions:
172-
references.add(this.expressions[0].this.lower())
175+
args = this.expressions
176+
177+
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string:
178+
references.add(args[0].this.lower())
173179
elif isinstance(n, d.MacroVar):
174180
references.add(n.name.lower())
175181
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:

tests/core/test_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11246,3 +11246,22 @@ def test_each_macro_with_paren_expression_arg(assert_exp_eq):
1124611246
'value' AS "property1"
1124711247
""",
1124811248
)
11249+
11250+
11251+
@pytest.mark.parametrize(
11252+
"macro_func, variables",
11253+
[
11254+
("@M(@v1)", {"v1"}),
11255+
("@M(@{v1})", {"v1"}),
11256+
("@M(@SQL('@v1'))", {"v1"}),
11257+
("@M(@'@{v1}_foo')", {"v1"}),
11258+
("@M1(@VAR('v1'))", {"v1"}),
11259+
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v3"}),
11260+
("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}),
11261+
],
11262+
)
11263+
def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None:
11264+
from sqlmesh.core.model.common import _extract_macro_func_variable_references
11265+
11266+
macro_func_ast = parse_one(macro_func)
11267+
assert _extract_macro_func_variable_references(macro_func_ast) == variables

0 commit comments

Comments
 (0)