diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 13351c72c77c6..af5263654798c 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -220,16 +220,19 @@ def get_simple_context(): from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse from airflow.models import DagRun - # TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context - # from the scheduler - - # Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields - # are not in the current session. + # TODO: Use the Execution API from within the triggerer/executor to fetch full context + # at execution time instead of sending this minimal context from the scheduler. + # This will allow template rendering with the standard Airflow Context and avoid + # bloating trigger kwargs with serialized context in the DB. + # Tracked at: https://github.com/apache/airflow/pull/64984 + + # Fetch the DagRun from the database again to avoid errors when self.dagrun's + # relationship fields are not in the current session. 
dagrun = session.get(DagRun, self.dagrun_id) return { "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"), - "deadline": {"id": self.id, "deadline_time": self.deadline_time}, + "deadline": {"id": str(self.id), "deadline_time": self.deadline_time}, } if isinstance(self.callback, TriggererCallback): diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index b27920fed8614..18c3654074b97 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -17,6 +17,7 @@ from __future__ import annotations +import inspect import logging import traceback from collections.abc import AsyncIterator @@ -32,6 +33,21 @@ PAYLOAD_BODY_KEY = "body" +def _is_notifier_class(callback: Any) -> bool: + """ + Check if the callback is a BaseNotifier subclass (not an instance). + + Uses duck-typing (checks for ``async_notify``, ``template_fields`` and ``__await__``) + to avoid importing ``airflow.sdk`` in core. + """ + return ( + inspect.isclass(callback) + and hasattr(callback, "async_notify") + and hasattr(callback, "template_fields") + and hasattr(callback, "__await__") + ) + + class CallbackTrigger(BaseTrigger): """Trigger that executes a callback function asynchronously.""" @@ -52,9 +68,20 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) - # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` + + # TODO: Fetch full context via the Execution API at execution time rather than + # relying on the minimal context passed from the scheduler in callback_kwargs. + # This will provide fresh context and use the standard Airflow Context object, + # avoiding DB bloat from serialized context in trigger kwargs. 
+ # Tracked at: https://github.com/apache/airflow/pull/64984 context = self.callback_kwargs.pop("context", None) + # Render Jinja templates in kwargs for plain function callbacks. + # Notifiers handle their own template rendering in __await__ via + # render_template_fields(), so we skip rendering here for them. + if context is not None and not _is_notifier_class(callback): + self.callback_kwargs = self.render_template(self.callback_kwargs, context) + if accepts_context(callback) and context is not None: result = await callback(**self.callback_kwargs, context=context) else: diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index 94c6977ae0c16..5251de16c2aaa 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -234,10 +234,56 @@ def test_handle_miss(self, dagrun, session): context = callback_kwargs.pop("context") assert callback_kwargs == TEST_CALLBACK_KWARGS - assert context["deadline"]["id"] == deadline_orm.id + assert context["deadline"]["id"] == str(deadline_orm.id) assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp() assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json") + @pytest.mark.db_test + def test_handle_miss_serializes_context_through_trigger(self, dagrun, session): + """Verify that the enriched context can be serialized through the full trigger chain. + + Exercises: handle_miss() -> TriggererCallback.queue() -> Trigger.from_object() -> serialize(). + This catches type errors (UUIDs, datetimes, nested dicts) that would only surface at + trigger creation time, not when building the context dict. 
+ """ + from airflow.models.trigger import Trigger + + deadline_orm = Deadline( + deadline_time=DEFAULT_DATE, + callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + dagrun_id=dagrun.id, + dag_id=dagrun.dag_id, + deadline_alert_id=None, + ) + session.add(deadline_orm) + session.flush() + + # Call handle_miss WITHOUT mocking queue() -- let the full serialization path execute + deadline_orm.handle_miss(session) + session.flush() + + assert deadline_orm.missed + + # The callback should now have a trigger created via Trigger.from_object() + assert deadline_orm.callback.trigger is not None + trigger = deadline_orm.callback.trigger + assert isinstance(trigger, Trigger) + + # Verify the trigger was persisted and can be loaded back + session.refresh(trigger) + assert trigger.classpath == "airflow.triggers.callback.CallbackTrigger" + + # Verify the serialized kwargs contain the simple context + # (this exercises the serialize() path with UUIDs, datetimes, nested dicts) + trigger_kwargs = trigger.kwargs + assert "callback_kwargs" in trigger_kwargs + callback_kwargs = trigger_kwargs["callback_kwargs"] + assert "context" in callback_kwargs + context = callback_kwargs["context"] + assert "dag_run" in context + assert "deadline" in context + assert context["deadline"]["id"] == str(deadline_orm.id) + @pytest.mark.db_test class TestCalculatedDeadlineDatabaseCalls: diff --git a/airflow-core/tests/unit/triggers/test_callback.py b/airflow-core/tests/unit/triggers/test_callback.py index 99eca603323bb..670e3f14b65a5 100644 --- a/airflow-core/tests/unit/triggers/test_callback.py +++ b/airflow-core/tests/unit/triggers/test_callback.py @@ -23,16 +23,31 @@ from airflow.models.callback import CallbackState from airflow.sdk import BaseNotifier -from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, CallbackTrigger +from airflow.triggers.callback import ( + PAYLOAD_BODY_KEY, + PAYLOAD_STATUS_KEY, + CallbackTrigger, + _is_notifier_class, +) TEST_MESSAGE = 
"test_message" TEST_CALLBACK_PATH = "classpath.test_callback" -TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}} +TEST_CONTEXT = { + "dag_run": {"dag_id": "test_dag"}, + "dag_id": "test_dag", + "run_id": "test_run", + "ds": "2024-01-01", + "ts": "2024-01-01T00:00:00+00:00", + "deadline": {"id": "abc-123", "deadline_time": "2024-01-01T01:00:00+00:00"}, +} +TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": TEST_CONTEXT} class ExampleAsyncNotifier(BaseNotifier): """Example of a properly implemented async notifier.""" + template_fields = ("message",) + def __init__(self, message, **kwargs): super().__init__(**kwargs) self.message = message @@ -80,7 +95,7 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs): @pytest.mark.asyncio async def test_run_success_with_async_function(self, trigger, mock_import_string): - """Test trigger handles async functions correctly.""" + """Test trigger handles async functions correctly and renders templates.""" callback_return_value = "some value" mock_callback = mock.AsyncMock(return_value=callback_return_value) mock_import_string.return_value = mock_callback @@ -92,14 +107,14 @@ async def test_run_success_with_async_function(self, trigger, mock_import_string success_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) - # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through - mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) + # Context is popped and passed separately; kwargs are rendered (no-op here since no templates) + mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT) assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value @pytest.mark.asyncio async def test_run_success_with_notifier(self, trigger, mock_import_string): - """Test trigger handles async 
notifier classes correctly.""" + """Test trigger handles async notifier classes correctly without pre-rendering.""" mock_import_string.return_value = ExampleAsyncNotifier trigger_gen = trigger.run() @@ -112,7 +127,7 @@ async def test_run_success_with_notifier(self, trigger, mock_import_string): assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS assert ( success_event.payload[PAYLOAD_BODY_KEY] - == f"Async notification: {TEST_MESSAGE}, context: {{'dag_run': 'test'}}" + == f"Async notification: {TEST_MESSAGE}, context: {TEST_CONTEXT}" ) @pytest.mark.asyncio @@ -128,7 +143,165 @@ async def test_run_failure(self, trigger, mock_import_string): failure_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) - # AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through - mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) + # Context is popped and passed separately; kwargs are rendered (no-op here since no templates) + mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT) assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg]) + + +class TestTemplateRendering: + """Tests for Jinja2 template rendering in callback kwargs.""" + + @pytest.mark.asyncio + async def test_run_renders_jinja_templates_in_function_kwargs(self): + """Plain async function callbacks get their kwargs rendered.""" + context = {"dag_id": "my_dag", "ds": "2024-06-15"} + trigger = CallbackTrigger( + callback_path="classpath.test", + callback_kwargs={ + "message": "DAG {{ dag_id }} missed deadline at {{ ds }}", + "context": context, + }, + ) + mock_callback = mock.AsyncMock(return_value="ok") + with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback): + events = [event async for event in trigger.run()] + + assert 
events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + mock_callback.assert_called_once_with( + message="DAG my_dag missed deadline at 2024-06-15", + context=context, + ) + + @pytest.mark.asyncio + async def test_run_does_not_double_render_notifier_kwargs(self): + """Notifier classes should NOT have kwargs pre-rendered -- they handle it themselves.""" + context = {"dag_id": "my_dag", "ds": "2024-06-15"} + trigger = CallbackTrigger( + callback_path="classpath.test", + callback_kwargs={ + "message": "DAG {{ dag_id }}", + "context": context, + }, + ) + with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier): + events = [event async for event in trigger.run()] + + assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + # The notifier's __await__ renders template_fields, so the final output + # should show the rendered message (rendered by the notifier, not pre-rendered). + assert "DAG my_dag" in events[-1].payload[PAYLOAD_BODY_KEY] + + @pytest.mark.asyncio + async def test_run_renders_nested_kwargs(self): + """Template rendering works recursively on nested dicts and lists.""" + context = {"dag_id": "etl_pipeline"} + trigger = CallbackTrigger( + callback_path="classpath.test", + callback_kwargs={ + "recipients": ["{{ dag_id }}-team@example.com"], + "metadata": {"dag": "{{ dag_id }}"}, + "context": context, + }, + ) + mock_callback = mock.AsyncMock(return_value="ok") + with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback): + events = [event async for event in trigger.run()] + + assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + mock_callback.assert_called_once_with( + recipients=["etl_pipeline-team@example.com"], + metadata={"dag": "etl_pipeline"}, + context=context, + ) + + @pytest.mark.asyncio + async def test_run_skips_rendering_when_no_context(self): + """Without context, kwargs pass through unrendered.""" + trigger = CallbackTrigger( + 
callback_path="classpath.test", + callback_kwargs={"message": "{{ dag_id }}"}, + ) + mock_callback = mock.AsyncMock(return_value="ok") + with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback): + events = [event async for event in trigger.run()] + + assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + mock_callback.assert_called_once_with(message="{{ dag_id }}") + + @pytest.mark.asyncio + async def test_notifier_template_fields_rendered_with_context(self): + """Notifier template_fields are rendered using the provided context.""" + context = {"dag_id": "my_dag", "ds": "2024-06-15"} + trigger = CallbackTrigger( + callback_path="classpath.test", + callback_kwargs={ + "message": "Alert for {{ dag_id }} on {{ ds }}", + "context": context, + }, + ) + with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier): + events = [event async for event in trigger.run()] + + assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + # The notifier's __await__ renders template_fields (self.message), so the + # notification body contains the rendered message. 
+ assert "Alert for my_dag on 2024-06-15" in events[-1].payload[PAYLOAD_BODY_KEY] + + +class TestHelpers: + """Tests for module-level helper functions.""" + + def test_is_notifier_class_with_notifier(self): + assert _is_notifier_class(ExampleAsyncNotifier) is True + + def test_is_notifier_class_with_function(self): + async def my_func(): + pass + + assert _is_notifier_class(my_func) is False + + def test_is_notifier_class_with_non_notifier_class(self): + class MyClass: + pass + + assert _is_notifier_class(MyClass) is False + + def test_is_notifier_class_with_notifier_instance(self): + """Instances are not classes -- should return False.""" + instance = ExampleAsyncNotifier(message="hi") + assert _is_notifier_class(instance) is False + + def test_render_template_renders_strings(self): + """CallbackTrigger.render_template renders string values using context.""" + trigger = CallbackTrigger(callback_path="", callback_kwargs={}) + result = trigger.render_template( + {"message": "Hello {{ name }}", "count": 5}, + {"name": "World"}, + ) + assert result == {"message": "Hello World", "count": 5} + + def test_render_template_handles_nested_structures(self): + """CallbackTrigger.render_template works recursively on nested structures.""" + trigger = CallbackTrigger(callback_path="", callback_kwargs={}) + result = trigger.render_template( + {"items": ["{{ x }}", "{{ y }}"], "meta": {"key": "{{ x }}"}}, + {"x": "a", "y": "b"}, + ) + assert result == {"items": ["a", "b"], "meta": {"key": "a"}} + + def test_render_template_missing_key_renders_empty(self): + """Missing context keys render as empty strings (Jinja2 default undefined).""" + trigger = CallbackTrigger(callback_path="", callback_kwargs={}) + result = trigger.render_template( + {"message": "Hello {{ nonexistent }}"}, + {"name": "World"}, + ) + assert result == {"message": "Hello "} + + def test_render_template_no_templates_is_noop(self): + """Non-template strings and non-string values pass through unchanged.""" + trigger 
= CallbackTrigger(callback_path="", callback_kwargs={}) + kwargs = {"message": "plain text", "count": 42} + result = trigger.render_template(kwargs, {"dag_id": "test"}) + assert result == kwargs