Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions airflow-core/src/airflow/models/deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,19 @@ def get_simple_context():
from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun

# TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context
# from the scheduler

# Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields
# are not in the current session.
# TODO: Use the Execution API from within the triggerer/executor to fetch full context
# at execution time instead of sending this minimal context from the scheduler.
# This will allow template rendering with the standard Airflow Context and avoid
# bloating trigger kwargs with serialized context in the DB.
# Tracked at: https://github.com/apache/airflow/pull/64984

# Fetch the DagRun from the database again to avoid errors when self.dagrun's
# relationship fields are not in the current session.
dagrun = session.get(DagRun, self.dagrun_id)

return {
"dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
"deadline": {"id": self.id, "deadline_time": self.deadline_time},
"deadline": {"id": str(self.id), "deadline_time": self.deadline_time},
}

if isinstance(self.callback, TriggererCallback):
Expand Down
29 changes: 28 additions & 1 deletion airflow-core/src/airflow/triggers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import inspect
import logging
import traceback
from collections.abc import AsyncIterator
Expand All @@ -32,6 +33,21 @@
PAYLOAD_BODY_KEY = "body"


def _is_notifier_class(callback: Any) -> bool:
"""
Check if the callback is a BaseNotifier subclass (not an instance).

Uses duck-typing (checks for ``async_notify`` and ``template_fields``)
to avoid importing ``airflow.sdk`` in core.
"""
return (
inspect.isclass(callback)
and hasattr(callback, "async_notify")
and hasattr(callback, "template_fields")
and hasattr(callback, "__await__")
)


class CallbackTrigger(BaseTrigger):
"""Trigger that executes a callback function asynchronously."""

Expand All @@ -52,9 +68,20 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
try:
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
callback = import_string(self.callback_path)
# TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs`

# TODO: Fetch full context via the Execution API at execution time rather than
# relying on the minimal context passed from the scheduler in callback_kwargs.
# This will provide fresh context and use the standard Airflow Context object,
# avoiding DB bloat from serialized context in trigger kwargs.
# Tracked at: https://github.com/apache/airflow/pull/64984
context = self.callback_kwargs.pop("context", None)

# Render Jinja templates in kwargs for plain function callbacks.
# Notifiers handle their own template rendering in __await__ via
# render_template_fields(), so we skip rendering here for them.
if context is not None and not _is_notifier_class(callback):
self.callback_kwargs = self.render_template(self.callback_kwargs, context)

if accepts_context(callback) and context is not None:
result = await callback(**self.callback_kwargs, context=context)
else:
Expand Down
48 changes: 47 additions & 1 deletion airflow-core/tests/unit/models/test_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,56 @@ def test_handle_miss(self, dagrun, session):
context = callback_kwargs.pop("context")
assert callback_kwargs == TEST_CALLBACK_KWARGS

assert context["deadline"]["id"] == deadline_orm.id
assert context["deadline"]["id"] == str(deadline_orm.id)
assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp()
assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json")

@pytest.mark.db_test
def test_handle_miss_serializes_context_through_trigger(self, dagrun, session):
"""Verify that the enriched context can be serialized through the full trigger chain.

Exercises: handle_miss() -> TriggererCallback.queue() -> Trigger.from_object() -> serialize().
This catches type errors (UUIDs, datetimes, nested dicts) that would only surface at
trigger creation time, not when building the context dict.
"""
from airflow.models.trigger import Trigger

deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
dagrun_id=dagrun.id,
dag_id=dagrun.dag_id,
deadline_alert_id=None,
)
session.add(deadline_orm)
session.flush()

# Call handle_miss WITHOUT mocking queue() -- let the full serialization path execute
deadline_orm.handle_miss(session)
session.flush()

assert deadline_orm.missed

# The callback should now have a trigger created via Trigger.from_object()
assert deadline_orm.callback.trigger is not None
trigger = deadline_orm.callback.trigger
assert isinstance(trigger, Trigger)

# Verify the trigger was persisted and can be loaded back
session.refresh(trigger)
assert trigger.classpath == "airflow.triggers.callback.CallbackTrigger"

# Verify the serialized kwargs contain the simple context
# (this exercises the serialize() path with UUIDs, datetimes, nested dicts)
trigger_kwargs = trigger.kwargs
assert "callback_kwargs" in trigger_kwargs
callback_kwargs = trigger_kwargs["callback_kwargs"]
assert "context" in callback_kwargs
context = callback_kwargs["context"]
assert "dag_run" in context
assert "deadline" in context
assert context["deadline"]["id"] == str(deadline_orm.id)


@pytest.mark.db_test
class TestCalculatedDeadlineDatabaseCalls:
Expand Down
191 changes: 182 additions & 9 deletions airflow-core/tests/unit/triggers/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,31 @@

from airflow.models.callback import CallbackState
from airflow.sdk import BaseNotifier
from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, CallbackTrigger
from airflow.triggers.callback import (
PAYLOAD_BODY_KEY,
PAYLOAD_STATUS_KEY,
CallbackTrigger,
_is_notifier_class,
)

TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback"
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}}
TEST_CONTEXT = {
"dag_run": {"dag_id": "test_dag"},
"dag_id": "test_dag",
"run_id": "test_run",
"ds": "2024-01-01",
"ts": "2024-01-01T00:00:00+00:00",
"deadline": {"id": "abc-123", "deadline_time": "2024-01-01T01:00:00+00:00"},
}
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": TEST_CONTEXT}


class ExampleAsyncNotifier(BaseNotifier):
"""Example of a properly implemented async notifier."""

template_fields = ("message",)

def __init__(self, message, **kwargs):
super().__init__(**kwargs)
self.message = message
Expand Down Expand Up @@ -80,7 +95,7 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs):

@pytest.mark.asyncio
async def test_run_success_with_async_function(self, trigger, mock_import_string):
"""Test trigger handles async functions correctly."""
"""Test trigger handles async functions correctly and renders templates."""
callback_return_value = "some value"
mock_callback = mock.AsyncMock(return_value=callback_return_value)
mock_import_string.return_value = mock_callback
Expand All @@ -92,14 +107,14 @@ async def test_run_success_with_async_function(self, trigger, mock_import_string

success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
# Context is popped and passed separately; kwargs are rendered (no-op here since no templates)
mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT)
assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value

@pytest.mark.asyncio
async def test_run_success_with_notifier(self, trigger, mock_import_string):
"""Test trigger handles async notifier classes correctly."""
"""Test trigger handles async notifier classes correctly without pre-rendering."""
mock_import_string.return_value = ExampleAsyncNotifier

trigger_gen = trigger.run()
Expand All @@ -112,7 +127,7 @@ async def test_run_success_with_notifier(self, trigger, mock_import_string):
assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
assert (
success_event.payload[PAYLOAD_BODY_KEY]
== f"Async notification: {TEST_MESSAGE}, context: {{'dag_run': 'test'}}"
== f"Async notification: {TEST_MESSAGE}, context: {TEST_CONTEXT}"
)

@pytest.mark.asyncio
Expand All @@ -128,7 +143,165 @@ async def test_run_failure(self, trigger, mock_import_string):

failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
# AsyncMock accepts **kwargs, so _accepts_context returns True and context is passed through
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
# Context is popped and passed separately; kwargs are rendered (no-op here since no templates)
mock_callback.assert_called_once_with(message=TEST_MESSAGE, context=TEST_CONTEXT)
assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])


class TestTemplateRendering:
"""Tests for Jinja2 template rendering in callback kwargs."""

@pytest.mark.asyncio
async def test_run_renders_jinja_templates_in_function_kwargs(self):
"""Plain async function callbacks get their kwargs rendered."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "DAG {{ dag_id }} missed deadline at {{ ds }}",
"context": context,
},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(
message="DAG my_dag missed deadline at 2024-06-15",
context=context,
)

@pytest.mark.asyncio
async def test_run_does_not_double_render_notifier_kwargs(self):
"""Notifier classes should NOT have kwargs pre-rendered -- they handle it themselves."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "DAG {{ dag_id }}",
"context": context,
},
)
with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
# The notifier's __await__ renders template_fields, so the final output
# should show the rendered message (rendered by the notifier, not pre-rendered).
assert "DAG my_dag" in events[-1].payload[PAYLOAD_BODY_KEY]

@pytest.mark.asyncio
async def test_run_renders_nested_kwargs(self):
"""Template rendering works recursively on nested dicts and lists."""
context = {"dag_id": "etl_pipeline"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"recipients": ["{{ dag_id }}-team@example.com"],
"metadata": {"dag": "{{ dag_id }}"},
"context": context,
},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(
recipients=["etl_pipeline-team@example.com"],
metadata={"dag": "etl_pipeline"},
context=context,
)

@pytest.mark.asyncio
async def test_run_skips_rendering_when_no_context(self):
"""Without context, kwargs pass through unrendered."""
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={"message": "{{ dag_id }}"},
)
mock_callback = mock.AsyncMock(return_value="ok")
with mock.patch("airflow.triggers.callback.import_string", return_value=mock_callback):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
mock_callback.assert_called_once_with(message="{{ dag_id }}")

@pytest.mark.asyncio
async def test_notifier_template_fields_rendered_with_context(self):
"""Notifier template_fields are rendered using the provided context."""
context = {"dag_id": "my_dag", "ds": "2024-06-15"}
trigger = CallbackTrigger(
callback_path="classpath.test",
callback_kwargs={
"message": "Alert for {{ dag_id }} on {{ ds }}",
"context": context,
},
)
with mock.patch("airflow.triggers.callback.import_string", return_value=ExampleAsyncNotifier):
events = [event async for event in trigger.run()]

assert events[-1].payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS
# The notifier's __await__ renders template_fields (self.message), so the
# notification body contains the rendered message.
assert "Alert for my_dag on 2024-06-15" in events[-1].payload[PAYLOAD_BODY_KEY]


class TestHelpers:
"""Tests for module-level helper functions."""

def test_is_notifier_class_with_notifier(self):
assert _is_notifier_class(ExampleAsyncNotifier) is True

def test_is_notifier_class_with_function(self):
async def my_func():
pass

assert _is_notifier_class(my_func) is False

def test_is_notifier_class_with_non_notifier_class(self):
class MyClass:
pass

assert _is_notifier_class(MyClass) is False

def test_is_notifier_class_with_notifier_instance(self):
"""Instances are not classes -- should return False."""
instance = ExampleAsyncNotifier(message="hi")
assert _is_notifier_class(instance) is False

def test_render_template_renders_strings(self):
"""CallbackTrigger.render_template renders string values using context."""
trigger = CallbackTrigger(callback_path="", callback_kwargs={})
result = trigger.render_template(
{"message": "Hello {{ name }}", "count": 5},
{"name": "World"},
)
assert result == {"message": "Hello World", "count": 5}

def test_render_template_handles_nested_structures(self):
"""CallbackTrigger.render_template works recursively on nested structures."""
trigger = CallbackTrigger(callback_path="", callback_kwargs={})
result = trigger.render_template(
{"items": ["{{ x }}", "{{ y }}"], "meta": {"key": "{{ x }}"}},
{"x": "a", "y": "b"},
)
assert result == {"items": ["a", "b"], "meta": {"key": "a"}}

def test_render_template_missing_key_renders_empty(self):
"""Missing context keys render as empty strings (Jinja2 default undefined)."""
trigger = CallbackTrigger(callback_path="", callback_kwargs={})
result = trigger.render_template(
{"message": "Hello {{ nonexistent }}"},
{"name": "World"},
)
assert result == {"message": "Hello "}

def test_render_template_no_templates_is_noop(self):
"""Non-template strings and non-string values pass through unchanged."""
trigger = CallbackTrigger(callback_path="", callback_kwargs={})
kwargs = {"message": "plain text", "count": 42}
result = trigger.render_template(kwargs, {"dag_id": "test"})
assert result == kwargs
Loading