Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions sqlmesh/core/macros.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import ast
import inspect
import sys
import types
Expand Down Expand Up @@ -1453,10 +1454,7 @@ def call_macro(
bound = sig.bind(*provided_args, **provided_kwargs)
bound.apply_defaults()

try:
annotations = t.get_type_hints(func, localns=get_supported_types())
except (NameError, TypeError): # forward references aren't handled
annotations = {}
annotations = _resolve_macro_annotations(func)

# If the macro is annotated, we try coerce the actual parameters to the corresponding types
if annotations:
Expand All @@ -1478,6 +1476,45 @@ def call_macro(
return func(*bound.args, **bound.kwargs)


def _resolve_macro_annotations(func: t.Callable) -> t.Dict[str, t.Any]:
annotations: t.Dict[str, t.Any] = {}
namespace = {**get_supported_types(), **func.__globals__}

for name, annotation in inspect.get_annotations(func, eval_str=False).items():
try:
if isinstance(annotation, str):
annotation = _eval_macro_annotation(annotation, namespace)
annotations[name] = annotation
except (AttributeError, NameError, SyntaxError, TypeError, ValueError):
continue

return annotations


def _eval_macro_annotation(annotation: str, namespace: t.Dict[str, t.Any]) -> t.Any:
expr = ast.parse(annotation, mode="eval")

for node in ast.walk(expr):
if not isinstance(
node,
(
ast.Expression,
ast.Attribute,
ast.BinOp,
ast.Constant,
ast.List,
ast.Load,
ast.Name,
ast.Subscript,
ast.Tuple,
ast.BitOr,
),
):
raise ValueError(f"Unsupported annotation expression: {annotation}")

return eval(compile(expr, "<sqlmesh annotation>", "eval"), namespace)


def _coerce(
expr: t.Any,
typ: t.Any,
Expand Down
30 changes: 29 additions & 1 deletion sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:

func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments))
arg_defaults = (d for d in func_args.defaults + func_args.kw_defaults if d is not None)
arg_annotations = _function_annotations(root_node)

# ast.Name corresponds to variable references, such as foo or x.foo. The former is
# represented as Name(id=foo), and the latter as Attribute(value=Name(id=x) attr=foo)
arg_globals = [
n.id for default in arg_defaults for n in ast.walk(default) if isinstance(n, ast.Name)
n.id
for default in chain(arg_defaults, arg_annotations)
for n in ast.walk(default)
if isinstance(n, ast.Name)
]

code = func.__code__
Expand All @@ -114,6 +118,30 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
return variables


def _function_annotations(root_node: ast.Module) -> t.List[ast.expr]:
func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments))
func_def = next(
node for node in ast.walk(root_node) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
)

annotations = [
arg.annotation
for arg in (
func_args.posonlyargs
+ func_args.args
+ func_args.kwonlyargs
+ ([func_args.vararg] if func_args.vararg else [])
+ ([func_args.kwarg] if func_args.kwarg else [])
)
if arg and arg.annotation is not None
]

if func_def.returns is not None:
annotations.append(func_def.returns)

return annotations


class ClassFoundException(Exception):
pass

Expand Down
110 changes: 110 additions & 0 deletions tests/core/test_macros.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import typing as t
from datetime import datetime, date
import importlib.util
import sys
import textwrap
from pathlib import Path

import pytest
from sqlglot import MappingSchema, ParseError, exp, parse_one

from sqlmesh import SQL as SQLType
from sqlmesh.core import constants as c, dialect as d
from sqlmesh.core.dialect import StagedFilePath
from sqlmesh.core.macros import SQL, MacroEvalError, MacroEvaluator, macro
from sqlmesh.utils.date import to_datetime, to_date
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.metaprogramming import Executable
from sqlmesh.utils.metaprogramming import build_env, serialize_env
from sqlmesh.core.macros import RuntimeStage


Expand Down Expand Up @@ -108,6 +114,64 @@ def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b"
)


def load_alias_macro_module(tmp_path: Path):
module_path = tmp_path / "macro_alias_module.py"
module_path.write_text(
textwrap.dedent(
"""
from __future__ import annotations

from sqlmesh import SQL as SQLType, macro
from sqlmesh.core.macros import MacroEvaluator, SQL

@macro()
def plain_sql_macro(
evaluator: MacroEvaluator,
driver_property: str,
property_name: str,
fallback_value: SQL,
) -> SQL:
assert isinstance(driver_property, str)
assert isinstance(property_name, str)
assert isinstance(fallback_value, str)
return fallback_value

@macro()
def alias_sql_macro(
evaluator: MacroEvaluator,
driver_property: str,
property_name: str,
fallback_value: SQLType,
) -> SQLType:
assert isinstance(driver_property, str)
assert isinstance(property_name, str)
assert isinstance(fallback_value, str)
return fallback_value

@macro()
def partial_resolution_macro(
evaluator: MacroEvaluator,
good_value: str,
fallback_value: SQLType,
intentionally_unresolved: MissingAlias,
) -> SQLType:
assert isinstance(good_value, str)
assert isinstance(fallback_value, str)
return fallback_value
"""
)
)

spec = importlib.util.spec_from_file_location("macro_alias_module", module_path)
assert spec and spec.loader

module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)

return module, spec.name


def test_star(assert_exp_eq) -> None:
sql = """SELECT @STAR(foo) FROM foo"""
expected_sql = "SELECT CAST([foo].[a] AS DATETIMEOFFSET) AS [a], CAST([foo].[b] AS INTEGER) AS [b] FROM foo"
Expand Down Expand Up @@ -220,6 +284,52 @@ def test_case(macro_evaluator: MacroEvaluator) -> None:
assert macro.get_registry()["upper"]


def test_macro_type_annotation_aliases(assert_exp_eq, tmp_path: Path) -> None:
module, module_name = load_alias_macro_module(tmp_path)
env: dict[str, t.Any] = {}

try:
for macro_name in ("plain_sql_macro", "alias_sql_macro", "partial_resolution_macro"):
build_env(getattr(module, macro_name), env=env, name=macro_name, path=tmp_path)

evaluator = MacroEvaluator("hive", python_env=serialize_env(env, path=tmp_path))

assert_exp_eq(
evaluator.transform(
parse_one(
"SELECT @plain_sql_macro('entitlement_sku', 'seat_id', es.seat_id) FROM foo es",
read="hive",
)
),
"SELECT es.seat_id FROM foo es",
dialect="hive",
)
assert_exp_eq(
evaluator.transform(
parse_one(
"SELECT @alias_sql_macro('entitlement_sku', 'seat_id', es.seat_id) FROM foo es",
read="hive",
)
),
"SELECT es.seat_id FROM foo es",
dialect="hive",
)
assert_exp_eq(
evaluator.transform(
parse_one(
"SELECT @partial_resolution_macro('entitlement_sku', 'seat_id', es.seat_id, 1) FROM foo es",
read="hive",
)
),
"SELECT es.seat_id FROM foo es",
dialect="hive",
)
finally:
for macro_name in ("plain_sql_macro", "alias_sql_macro", "partial_resolution_macro"):
macro.registry().pop(macro_name, None)
sys.modules.pop(module_name, None)


def test_macro_var(macro_evaluator):
expression = parse_one("@x")
for k, v in [
Expand Down
6 changes: 6 additions & 0 deletions tests/utils/test_metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlglot import exp as expressions
from sqlglot.expressions import SQLGLOT_META, to_table
from sqlglot.optimizer.pushdown_projections import SELECT_ALL
from sqlmesh import SQL as SQLType

import tests.utils.test_date as test_date
from sqlmesh.core.dialect import normalize_model_name
Expand Down Expand Up @@ -178,6 +179,10 @@ def macro2() -> str:
return "2"


def annotation_alias_func(value: str, fallback_value: SQLType) -> SQLType:
return fallback_value


def test_func_globals() -> None:
assert func_globals(main_func) == {
"Y": 2,
Expand All @@ -194,6 +199,7 @@ def test_func_globals() -> None:
"function_with_custom_decorator": function_with_custom_decorator,
"SQLGLOT_META": SQLGLOT_META,
}
assert func_globals(annotation_alias_func) == {"SQLType": SQLType}
assert func_globals(other_func) == {
"X": 1,
"W": 0,
Expand Down