diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index c2d32bdecf..cdcdb1d02e 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import inspect import sys import types @@ -1453,10 +1454,7 @@ def call_macro( bound = sig.bind(*provided_args, **provided_kwargs) bound.apply_defaults() - try: - annotations = t.get_type_hints(func, localns=get_supported_types()) - except (NameError, TypeError): # forward references aren't handled - annotations = {} + annotations = _resolve_macro_annotations(func) # If the macro is annotated, we try coerce the actual parameters to the corresponding types if annotations: @@ -1478,6 +1476,45 @@ def call_macro( return func(*bound.args, **bound.kwargs) +def _resolve_macro_annotations(func: t.Callable) -> t.Dict[str, t.Any]: + annotations: t.Dict[str, t.Any] = {} + namespace = {**get_supported_types(), **func.__globals__} + + for name, annotation in inspect.get_annotations(func, eval_str=False).items(): + try: + if isinstance(annotation, str): + annotation = _eval_macro_annotation(annotation, namespace) + annotations[name] = annotation + except (AttributeError, NameError, SyntaxError, TypeError, ValueError): + continue + + return annotations + + +def _eval_macro_annotation(annotation: str, namespace: t.Dict[str, t.Any]) -> t.Any: + expr = ast.parse(annotation, mode="eval") + + for node in ast.walk(expr): + if not isinstance( + node, + ( + ast.Expression, + ast.Attribute, + ast.BinOp, + ast.Constant, + ast.List, + ast.Load, + ast.Name, + ast.Subscript, + ast.Tuple, + ast.BitOr, + ), + ): + raise ValueError(f"Unsupported annotation expression: {annotation}") + + return eval(compile(expr, "", "eval"), namespace) + + def _coerce( expr: t.Any, typ: t.Any, diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index cd77c36353..debb2a4287 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -93,11 +93,15 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]: func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments)) arg_defaults = (d for d in func_args.defaults + func_args.kw_defaults if d is not None) + arg_annotations = _function_annotations(root_node) # ast.Name corresponds to variable references, such as foo or x.foo. The former is # represented as Name(id=foo), and the latter as Attribute(value=Name(id=x) attr=foo) arg_globals = [ - n.id for default in arg_defaults for n in ast.walk(default) if isinstance(n, ast.Name) + n.id + for default in chain(arg_defaults, arg_annotations) + for n in ast.walk(default) + if isinstance(n, ast.Name) ] code = func.__code__ @@ -114,6 +118,30 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]: return variables +def _function_annotations(root_node: ast.Module) -> t.List[ast.expr]: + func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments)) + func_def = next( + node for node in ast.walk(root_node) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ) + + annotations = [ + arg.annotation + for arg in ( + func_args.posonlyargs + + func_args.args + + func_args.kwonlyargs + + ([func_args.vararg] if func_args.vararg else []) + + ([func_args.kwarg] if func_args.kwarg else []) + ) + if arg and arg.annotation is not None + ] + + if func_def.returns is not None: + annotations.append(func_def.returns) + + return annotations + + class ClassFoundException(Exception): pass diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index e37a7ec05b..5156f4321e 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -1,15 +1,21 @@ import typing as t from datetime import datetime, date +import importlib.util +import sys +import textwrap +from pathlib import Path import pytest from sqlglot import MappingSchema, ParseError, exp, parse_one +from sqlmesh import SQL as SQLType from sqlmesh.core import constants as c, dialect as d from sqlmesh.core.dialect import StagedFilePath from sqlmesh.core.macros import SQL, MacroEvalError, MacroEvaluator, macro from sqlmesh.utils.date import to_datetime, to_date from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.metaprogramming import build_env, serialize_env from sqlmesh.core.macros import RuntimeStage @@ -108,6 +114,64 @@ def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b" ) +def load_alias_macro_module(tmp_path: Path): + module_path = tmp_path / "macro_alias_module.py" + module_path.write_text( + textwrap.dedent( + """ + from __future__ import annotations + + from sqlmesh import SQL as SQLType, macro + from sqlmesh.core.macros import MacroEvaluator, SQL + + @macro() + def plain_sql_macro( + evaluator: MacroEvaluator, + driver_property: str, + property_name: str, + fallback_value: SQL, + ) -> SQL: + assert isinstance(driver_property, str) + assert isinstance(property_name, str) + assert isinstance(fallback_value, str) + return fallback_value + + @macro() + def alias_sql_macro( + evaluator: MacroEvaluator, + driver_property: str, + property_name: str, + fallback_value: SQLType, + ) -> SQLType: + assert isinstance(driver_property, str) + assert isinstance(property_name, str) + assert isinstance(fallback_value, str) + return fallback_value + + @macro() + def partial_resolution_macro( + evaluator: MacroEvaluator, + good_value: str, + fallback_value: SQLType, + intentionally_unresolved: MissingAlias, + ) -> SQLType: + assert isinstance(good_value, str) + assert isinstance(fallback_value, str) + return fallback_value + """ + ) + ) + + spec = importlib.util.spec_from_file_location("macro_alias_module", module_path) + assert spec and spec.loader + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + + return module, spec.name + + def test_star(assert_exp_eq) -> None: sql = """SELECT @STAR(foo) FROM foo""" expected_sql = "SELECT CAST([foo].[a] AS DATETIMEOFFSET) AS [a], CAST([foo].[b] AS INTEGER) AS [b] FROM foo" @@ -220,6 +284,52 @@ def test_case(macro_evaluator: MacroEvaluator) -> None: assert macro.get_registry()["upper"] +def test_macro_type_annotation_aliases(assert_exp_eq, tmp_path: Path) -> None: + module, module_name = load_alias_macro_module(tmp_path) + env: dict[str, t.Any] = {} + + try: + for macro_name in ("plain_sql_macro", "alias_sql_macro", "partial_resolution_macro"): + build_env(getattr(module, macro_name), env=env, name=macro_name, path=tmp_path) + + evaluator = MacroEvaluator("hive", python_env=serialize_env(env, path=tmp_path)) + + assert_exp_eq( + evaluator.transform( + parse_one( + "SELECT @plain_sql_macro('entitlement_sku', 'seat_id', es.seat_id) FROM foo es", + read="hive", + ) + ), + "SELECT es.seat_id FROM foo es", + dialect="hive", + ) + assert_exp_eq( + evaluator.transform( + parse_one( + "SELECT @alias_sql_macro('entitlement_sku', 'seat_id', es.seat_id) FROM foo es", + read="hive", + ) + ), + "SELECT es.seat_id FROM foo es", + dialect="hive", + ) + assert_exp_eq( + evaluator.transform( + parse_one( + "SELECT @partial_resolution_macro('entitlement_sku', 'seat_id', es.seat_id, 1) FROM foo es", + read="hive", + ) + ), + "SELECT es.seat_id FROM foo es", + dialect="hive", + ) + finally: + for macro_name in ("plain_sql_macro", "alias_sql_macro", "partial_resolution_macro"): + macro.registry().pop(macro_name, None) + sys.modules.pop(module_name, None) + + def test_macro_var(macro_evaluator): expression = parse_one("@x") for k, v in [ diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py index 9a6f0c95cd..31c60b2acc 100644 --- a/tests/utils/test_metaprogramming.py +++ b/tests/utils/test_metaprogramming.py @@ -13,6 +13,7 @@ from sqlglot import exp as expressions from sqlglot.expressions import SQLGLOT_META, to_table from sqlglot.optimizer.pushdown_projections import SELECT_ALL +from sqlmesh import SQL as SQLType import tests.utils.test_date as test_date from sqlmesh.core.dialect import normalize_model_name @@ -178,6 +179,10 @@ def macro2() -> str: return "2" +def annotation_alias_func(value: str, fallback_value: SQLType) -> SQLType: + return fallback_value + + def test_func_globals() -> None: assert func_globals(main_func) == { "Y": 2, @@ -194,6 +199,7 @@ def test_func_globals() -> None: "function_with_custom_decorator": function_with_custom_decorator, "SQLGLOT_META": SQLGLOT_META, } + assert func_globals(annotation_alias_func) == {"SQLType": SQLType} assert func_globals(other_func) == { "X": 1, "W": 0,