Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ repos:
^src/airflow/timetables/base\.py$|
^src/airflow/timetables/simple\.py$|
^src/airflow/triggers/base\.py$|
^src/airflow/triggers/callback\.py$|
^src/airflow/utils/cli\.py$|
^src/airflow/utils/context\.py$|
^src/airflow/utils/dag_cycle_tester\.py$|
Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@
from airflow.sdk.execution_time.comms import (
CommsDecoder,
ConnectionResult,
DagRunResult,
DagRunStateResult,
DeleteVariable,
DeleteXCom,
DRCount,
ErrorResponse,
GetConnection,
GetDagRun,
GetDagRunState,
GetDRCount,
GetHITLDetailResponse,
Expand Down Expand Up @@ -303,6 +305,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| ConnectionResult
| VariableResult
| XComResult
| DagRunResult
| DagRunStateResult
| DRCount
| TICount
Expand All @@ -329,6 +332,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe
| SetXCom
| GetTICount
| GetTaskStates
| GetDagRun
| GetDagRunState
| GetDRCount
| GetPreviousTI
Expand Down Expand Up @@ -565,6 +569,9 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
states=msg.states,
)
resp = dr_count
elif isinstance(msg, GetDagRun):
dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id)
resp = DagRunResult.from_api_response(dr_resp)
elif isinstance(msg, GetDagRunState):
dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id)
resp = DagRunStateResult.from_api_response(dr_resp)
Expand Down
35 changes: 12 additions & 23 deletions airflow-core/src/airflow/models/deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,30 +215,21 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) ->

def handle_miss(self, session: Session):
"""Handle a missed deadline by queueing the callback."""

def get_simple_context():
from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun

# TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context
# from the scheduler

# Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields
# are not in the current session.
dagrun = session.get(DagRun, self.dagrun_id)

return {
"dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
"deadline": {"id": self.id, "deadline_time": self.deadline_time},
}
# Store only identifiers in kwargs; the callback executor (triggerer or executor subprocess)
# fetches the full DagRun context via the Execution API at runtime. This avoids DB bloat
# from serialized context and ensures context is fresh at execution time.
context_identifiers = {
"dag_id": self.dagrun.dag_id,
"run_id": self.dagrun.run_id,
"deadline_id": str(self.id),
"deadline_time": self.deadline_time.isoformat(),
}

if isinstance(self.callback, TriggererCallback):
# Update the callback with context before queuing
# Update the callback with identifiers before queuing
if "kwargs" not in self.callback.data:
self.callback.data["kwargs"] = {}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
"context": get_simple_context()
}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | context_identifiers

self.callback.queue()
session.add(self.callback)
Expand All @@ -247,9 +238,7 @@ def get_simple_context():
elif isinstance(self.callback, ExecutorCallback):
if "kwargs" not in self.callback.data:
self.callback.data["kwargs"] = {}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
"context": get_simple_context()
}
self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | context_identifiers
self.callback.data["deadline_id"] = str(self.id)
self.callback.data["dag_run_id"] = str(self.dagrun.id)
self.callback.data["dag_id"] = self.dagrun.dag_id
Expand Down
52 changes: 51 additions & 1 deletion airflow-core/src/airflow/triggers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,63 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
{attr: getattr(self, attr) for attr in ("callback_path", "callback_kwargs")},
)

async def _build_context(
    self, dag_id: str, run_id: str, deadline_id: str | None, deadline_time: str | None
) -> dict[str, Any]:
    """
    Build the callback context by fetching the DagRun through the Execution API.

    Only identifiers are stored at scheduling time; they are resolved here at
    execution time, so the context is always fresh and no serialized payload
    sits in the database. Returns an empty dict when the supervisor answers
    with an unexpected message type.
    """
    from airflow.sdk.execution_time.comms import DagRunResult, GetDagRun
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    dr = await SUPERVISOR_COMMS.asend(GetDagRun(dag_id=dag_id, run_id=run_id))
    if not isinstance(dr, DagRunResult):
        log.warning("Unexpected response type from GetDagRun: %s", type(dr))
        return {}

    def _iso(value):
        # Render datetimes as ISO-8601 strings; absent values stay None.
        return value.isoformat() if value else None

    ctx: dict[str, Any] = {
        "dag_run": dr.model_dump(mode="json"),
        "dag_id": dag_id,
        "run_id": run_id,
        "logical_date": _iso(dr.logical_date),
        "data_interval_start": _iso(dr.data_interval_start),
        "data_interval_end": _iso(dr.data_interval_end),
        "conf": dr.conf,
    }

    deadline_info = {"id": deadline_id, "deadline_time": deadline_time}
    if any(deadline_info.values()):
        # Include only the deadline fields that were actually provided.
        ctx["deadline"] = {k: v for k, v in deadline_info.items() if v is not None}

    return ctx

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
callback = import_string(self.callback_path)
# TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs`

# Backward compat: if a pre-upgrade callback stored "context" directly in kwargs, use it.
context = self.callback_kwargs.pop("context", None)

if context is None:
# New path: fetch context via the Execution API using stored identifiers.
dag_id = self.callback_kwargs.pop("dag_id", None)
run_id = self.callback_kwargs.pop("run_id", None)
deadline_id = self.callback_kwargs.pop("deadline_id", None)
deadline_time = self.callback_kwargs.pop("deadline_time", None)

if dag_id and run_id:
context = await self._build_context(dag_id, run_id, deadline_id, deadline_time)

if accepts_context(callback) and context is not None:
result = await callback(**self.callback_kwargs, context=context)
else:
Expand Down
2 changes: 0 additions & 2 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,7 +1776,6 @@ def get_type_names(union_type):
"GetAssetsByAlias",
"GetAssetEventByAsset",
"GetAssetEventByAssetAlias",
"GetDagRun",
"GetPrevSuccessfulDagRun",
"GetPreviousDagRun",
"GetTaskBreadcrumbs",
Expand Down Expand Up @@ -1815,7 +1814,6 @@ def get_type_names(union_type):
"AssetResult",
"AssetsByAliasResult",
"AssetEventsResult",
"DagRunResult",
"SentFDs",
"StartupDetails",
"TaskBreadcrumbsResult",
Expand Down
18 changes: 12 additions & 6 deletions airflow-core/tests/unit/models/test_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun
from airflow.models.deadline import Deadline, _fetch_from_db
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -231,12 +230,19 @@ def test_handle_miss(self, dagrun, session):
assert deadline_orm.missed

callback_kwargs = deadline_orm.callback.data["kwargs"]
context = callback_kwargs.pop("context")
assert callback_kwargs == TEST_CALLBACK_KWARGS

assert context["deadline"]["id"] == deadline_orm.id
assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp()
assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
# Verify that identifiers (not full context) are stored in kwargs
assert callback_kwargs["dag_id"] == dagrun.dag_id
assert callback_kwargs["run_id"] == dagrun.run_id
assert callback_kwargs["deadline_id"] == str(deadline_orm.id)
assert callback_kwargs["deadline_time"] == deadline_orm.deadline_time.isoformat()

# The original user-provided kwargs should still be present
for key, value in TEST_CALLBACK_KWARGS.items():
assert callback_kwargs[key] == value

# No serialized "context" key should be stored
assert "context" not in callback_kwargs


@pytest.mark.db_test
Expand Down
85 changes: 85 additions & 0 deletions airflow-core/tests/unit/triggers/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback"
TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}}
TEST_DAG_ID = "test_dag"
TEST_RUN_ID = "test_run_2024"
TEST_DEADLINE_ID = "abc-123"
TEST_DEADLINE_TIME = "2024-01-01T00:00:00+00:00"


class ExampleAsyncNotifier(BaseNotifier):
Expand Down Expand Up @@ -132,3 +136,84 @@ async def test_run_failure(self, trigger, mock_import_string):
mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg])

@pytest.mark.asyncio
async def test_run_fetches_context_via_execution_api(self, mock_import_string):
    """When kwargs carry identifiers (dag_id, run_id) but no 'context', the trigger fetches it via the API."""
    from airflow.sdk.execution_time.comms import DagRunResult

    async_cb = mock.AsyncMock(return_value="done")
    mock_import_string.return_value = async_cb

    trigger = CallbackTrigger(
        callback_path=TEST_CALLBACK_PATH,
        callback_kwargs={
            "message": TEST_MESSAGE,
            "dag_id": TEST_DAG_ID,
            "run_id": TEST_RUN_ID,
            "deadline_id": TEST_DEADLINE_ID,
            "deadline_time": TEST_DEADLINE_TIME,
        },
    )

    # Canned supervisor reply standing in for the Execution API response.
    dag_run_reply = DagRunResult(
        dag_id=TEST_DAG_ID,
        run_id=TEST_RUN_ID,
        logical_date="2024-01-01T00:00:00+00:00",
        data_interval_start="2024-01-01T00:00:00+00:00",
        data_interval_end="2024-01-02T00:00:00+00:00",
        run_after="2024-01-01T00:00:00+00:00",
        run_type="manual",
        state="running",
        conf={"key": "value"},
        consumed_asset_events=[],
    )

    with mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) as comms:
        comms.asend = mock.AsyncMock(return_value=dag_run_reply)

        events = trigger.run()
        first = await anext(events)
        assert first.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING

        second = await anext(events)
        assert second.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS

        # The AsyncMock accepts a ``context`` kwarg, so the fetched context must be forwarded.
        passed = async_cb.call_args[1]
        assert passed["message"] == TEST_MESSAGE
        assert "context" in passed
        fetched = passed["context"]
        assert fetched["dag_id"] == TEST_DAG_ID
        assert fetched["run_id"] == TEST_RUN_ID
        assert fetched["deadline"]["id"] == TEST_DEADLINE_ID
        assert fetched["deadline"]["deadline_time"] == TEST_DEADLINE_TIME
        assert fetched["conf"] == {"key": "value"}

@pytest.mark.asyncio
async def test_run_backward_compat_with_stored_context(self, mock_import_string):
    """Pre-upgrade callbacks that stored 'context' directly in kwargs keep working unchanged."""
    async_cb = mock.AsyncMock(return_value="done")
    mock_import_string.return_value = async_cb

    stored_context = {"dag_run": {"dag_id": "old_dag"}, "deadline": {"id": "old-id"}}
    trigger = CallbackTrigger(
        callback_path=TEST_CALLBACK_PATH,
        callback_kwargs={
            "message": TEST_MESSAGE,
            "context": stored_context,
        },
    )

    events = trigger.run()
    first = await anext(events)
    assert first.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING

    second = await anext(events)
    assert second.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS

    # The stored context must be forwarded verbatim, not re-fetched via the API.
    passed = async_cb.call_args[1]
    assert passed["context"] == stored_context
    assert passed["message"] == TEST_MESSAGE
11 changes: 8 additions & 3 deletions task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
DagRunResult,
ErrorResponse,
GetConnection,
GetDagRun,
GetVariable,
MaskSecret,
)
Expand Down Expand Up @@ -69,10 +71,10 @@ class _BundleInfoLike(Protocol):


# The set of messages that a callback subprocess can send to the supervisor.
# This is a minimal subset of ToSupervisor: read-only access to Connections
# and Variables, plus MaskSecret for the secrets masker.
# This is a minimal subset of ToSupervisor: read-only access to Connections,
# Variables, and DagRun details, plus MaskSecret for the secrets masker.
CallbackToSupervisor = Annotated[
GetConnection | GetVariable | MaskSecret,
GetConnection | GetDagRun | GetVariable | MaskSecret,
Field(discriminator="type"),
]

Expand Down Expand Up @@ -282,6 +284,9 @@ def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger,

if isinstance(msg, GetConnection):
resp, dump_opts = handle_get_connection(self.client, msg)
elif isinstance(msg, GetDagRun):
dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id)
resp = DagRunResult.from_api_response(dr_resp)
elif isinstance(msg, GetVariable):
resp, dump_opts = handle_get_variable(self.client, msg)
elif isinstance(msg, MaskSecret):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@

import socket
from dataclasses import dataclass
from datetime import datetime, timezone
from operator import attrgetter
from typing import Any
from unittest.mock import patch

import pytest
import structlog

from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, DagRunType
from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback
from airflow.sdk.execution_time.comms import (
ConnectionResult,
GetConnection,
GetDagRun,
GetVariable,
MaskSecret,
VariableResult,
Expand Down Expand Up @@ -171,6 +174,22 @@ class RequestCase:
),
mask_secret_args=("secret",),
),
RequestCase(
message=GetDagRun(dag_id="test_dag", run_id="test_run_1"),
test_id="get_dag_run",
client_mock=ClientMock(
method_path="dag_runs.get_detail",
args=("test_dag", "test_run_1"),
response=DagRun(
dag_id="test_dag",
run_id="test_run_1",
run_after=datetime(2024, 1, 1, tzinfo=timezone.utc),
run_type=DagRunType.MANUAL,
state=DagRunState.RUNNING,
consumed_asset_events=[],
),
),
),
RequestCase(
message=GetVariable(key="test_key"),
test_id="get_variable",
Expand Down