diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml index 838747ab24208..ce3352adaa5d3 100644 --- a/airflow-core/.pre-commit-config.yaml +++ b/airflow-core/.pre-commit-config.yaml @@ -388,6 +388,7 @@ repos: ^src/airflow/timetables/base\.py$| ^src/airflow/timetables/simple\.py$| ^src/airflow/triggers/base\.py$| + ^src/airflow/triggers/callback\.py$| ^src/airflow/utils/cli\.py$| ^src/airflow/utils/context\.py$| ^src/airflow/utils/dag_cycle_tester\.py$| diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 9edde3b276e56..e05601d0cdd6a 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -61,12 +61,14 @@ from airflow.sdk.execution_time.comms import ( CommsDecoder, ConnectionResult, + DagRunResult, DagRunStateResult, DeleteVariable, DeleteXCom, DRCount, ErrorResponse, GetConnection, + GetDagRun, GetDagRunState, GetDRCount, GetHITLDetailResponse, @@ -303,6 +305,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | ConnectionResult | VariableResult | XComResult + | DagRunResult | DagRunStateResult | DRCount | TICount @@ -329,6 +332,7 @@ def from_api_response(cls, response: HITLDetailResponse) -> HITLDetailResponseRe | SetXCom | GetTICount | GetTaskStates + | GetDagRun | GetDagRunState | GetDRCount | GetPreviousTI @@ -565,6 +569,9 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r states=msg.states, ) resp = dr_count + elif isinstance(msg, GetDagRun): + dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id) + resp = DagRunResult.from_api_response(dr_resp) elif isinstance(msg, GetDagRunState): dr_resp = self.client.dag_runs.get_state(msg.dag_id, msg.run_id) resp = DagRunStateResult.from_api_response(dr_resp) diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index 13351c72c77c6..1f30c75e2a39b 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -215,30 +215,21 @@ def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> def handle_miss(self, session: Session): """Handle a missed deadline by queueing the callback.""" - - def get_simple_context(): - from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse - from airflow.models import DagRun - - # TODO: Use the TaskAPI from within Triggerer to fetch full context instead of sending this context - # from the scheduler - - # Fetch the DagRun from the database again to avoid errors when self.dagrun's relationship fields - # are not in the current session. - dagrun = session.get(DagRun, self.dagrun_id) - - return { - "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"), - "deadline": {"id": self.id, "deadline_time": self.deadline_time}, - } + # Store only identifiers in kwargs; the callback executor (triggerer or executor subprocess) + # fetches the full DagRun context via the Execution API at runtime. This avoids DB bloat + # from serialized context and ensures context is fresh at execution time. + context_identifiers = { + "dag_id": self.dagrun.dag_id, + "run_id": self.dagrun.run_id, + "deadline_id": str(self.id), + "deadline_time": self.deadline_time.isoformat(), + } if isinstance(self.callback, TriggererCallback): - # Update the callback with context before queuing + # Update the callback with identifiers before queuing if "kwargs" not in self.callback.data: self.callback.data["kwargs"] = {} - self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { - "context": get_simple_context() - } + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | context_identifiers self.callback.queue() session.add(self.callback) @@ -247,9 +238,7 @@ def get_simple_context(): elif isinstance(self.callback, ExecutorCallback): if "kwargs" not in self.callback.data: self.callback.data["kwargs"] = {} - self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { - "context": get_simple_context() - } + self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | context_identifiers self.callback.data["deadline_id"] = str(self.id) self.callback.data["dag_run_id"] = str(self.dagrun.id) self.callback.data["dag_id"] = self.dagrun.dag_id diff --git a/airflow-core/src/airflow/triggers/callback.py b/airflow-core/src/airflow/triggers/callback.py index b27920fed8614..17d5dde84f1ea 100644 --- a/airflow-core/src/airflow/triggers/callback.py +++ b/airflow-core/src/airflow/triggers/callback.py @@ -48,13 +48,63 @@ def serialize(self) -> tuple[str, dict[str, Any]]: {attr: getattr(self, attr) for attr in ("callback_path", "callback_kwargs")}, ) + async def _build_context( + self, dag_id: str, run_id: str, deadline_id: str | None, deadline_time: str | None + ) -> dict[str, Any]: + """ + Fetch the DagRun via the Execution API and build a context dict for the callback. + + This replaces the previous approach of storing a serialized context in the database + at scheduling time. Fetching at execution time ensures the context is fresh and avoids + DB bloat from large serialized payloads. + """ + from airflow.sdk.execution_time.comms import DagRunResult, GetDagRun + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + response = await SUPERVISOR_COMMS.asend(GetDagRun(dag_id=dag_id, run_id=run_id)) + if not isinstance(response, DagRunResult): + log.warning("Unexpected response type from GetDagRun: %s", type(response)) + return {} + + context: dict[str, Any] = { + "dag_run": response.model_dump(mode="json"), + "dag_id": dag_id, + "run_id": run_id, + "logical_date": response.logical_date.isoformat() if response.logical_date else None, + "data_interval_start": ( + response.data_interval_start.isoformat() if response.data_interval_start else None + ), + "data_interval_end": ( + response.data_interval_end.isoformat() if response.data_interval_end else None + ), + "conf": response.conf, + } + + if deadline_id or deadline_time: + context["deadline"] = { + k: v for k, v in {"id": deadline_id, "deadline_time": deadline_time}.items() if v is not None + } + + return context + async def run(self) -> AsyncIterator[TriggerEvent]: try: yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING}) callback = import_string(self.callback_path) - # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs` + + # Backward compat: if a pre-upgrade callback stored "context" directly in kwargs, use it. context = self.callback_kwargs.pop("context", None) + if context is None: + # New path: fetch context via the Execution API using stored identifiers. + dag_id = self.callback_kwargs.pop("dag_id", None) + run_id = self.callback_kwargs.pop("run_id", None) + deadline_id = self.callback_kwargs.pop("deadline_id", None) + deadline_time = self.callback_kwargs.pop("deadline_time", None) + + if dag_id and run_id: + context = await self._build_context(dag_id, run_id, deadline_id, deadline_time) + if accepts_context(callback) and context is not None: result = await callback(**self.callback_kwargs, context=context) else: diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0501783b992d2..85076aca1b3a8 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1776,7 +1776,6 @@ def get_type_names(union_type): "GetAssetsByAlias", "GetAssetEventByAsset", "GetAssetEventByAssetAlias", - "GetDagRun", "GetPrevSuccessfulDagRun", "GetPreviousDagRun", "GetTaskBreadcrumbs", @@ -1815,7 +1814,6 @@ def get_type_names(union_type): "AssetResult", "AssetsByAliasResult", "AssetEventsResult", - "DagRunResult", "SentFDs", "StartupDetails", "TaskBreadcrumbsResult", diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index 94c6977ae0c16..9b967b8daad63 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -26,7 +26,6 @@ from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError -from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse from airflow.models import DagRun from airflow.models.deadline import Deadline, _fetch_from_db from airflow.providers.standard.operators.empty import EmptyOperator @@ -231,12 +230,19 @@ def test_handle_miss(self, dagrun, session): assert deadline_orm.missed callback_kwargs = deadline_orm.callback.data["kwargs"] - context = callback_kwargs.pop("context") - assert callback_kwargs == TEST_CALLBACK_KWARGS - assert context["deadline"]["id"] == deadline_orm.id - assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp() - assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json") + # Verify that identifiers (not full context) are stored in kwargs + assert callback_kwargs["dag_id"] == dagrun.dag_id + assert callback_kwargs["run_id"] == dagrun.run_id + assert callback_kwargs["deadline_id"] == str(deadline_orm.id) + assert callback_kwargs["deadline_time"] == deadline_orm.deadline_time.isoformat() + + # The original user-provided kwargs should still be present + for key, value in TEST_CALLBACK_KWARGS.items(): + assert callback_kwargs[key] == value + + # No serialized "context" key should be stored + assert "context" not in callback_kwargs @pytest.mark.db_test diff --git a/airflow-core/tests/unit/triggers/test_callback.py b/airflow-core/tests/unit/triggers/test_callback.py index 99eca603323bb..0e077b3c0127b 100644 --- a/airflow-core/tests/unit/triggers/test_callback.py +++ b/airflow-core/tests/unit/triggers/test_callback.py @@ -28,6 +28,10 @@ TEST_MESSAGE = "test_message" TEST_CALLBACK_PATH = "classpath.test_callback" TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run": "test"}} +TEST_DAG_ID = "test_dag" +TEST_RUN_ID = "test_run_2024" +TEST_DEADLINE_ID = "abc-123" +TEST_DEADLINE_TIME = "2024-01-01T00:00:00+00:00" class ExampleAsyncNotifier(BaseNotifier): @@ -132,3 +136,84 @@ async def test_run_failure(self, trigger, mock_import_string): mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) assert failure_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.FAILED assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg]) + + @pytest.mark.asyncio + async def test_run_fetches_context_via_execution_api(self, mock_import_string): + """When kwargs contain identifiers (dag_id, run_id) but no 'context', fetch via API.""" + from airflow.sdk.execution_time.comms import DagRunResult + + mock_callback = mock.AsyncMock(return_value="done") + mock_import_string.return_value = mock_callback + + trigger = CallbackTrigger( + callback_path=TEST_CALLBACK_PATH, + callback_kwargs={ + "message": TEST_MESSAGE, + "dag_id": TEST_DAG_ID, + "run_id": TEST_RUN_ID, + "deadline_id": TEST_DEADLINE_ID, + "deadline_time": TEST_DEADLINE_TIME, + }, + ) + + # Create a mock DagRunResult response + mock_dag_run_result = DagRunResult( + dag_id=TEST_DAG_ID, + run_id=TEST_RUN_ID, + logical_date="2024-01-01T00:00:00+00:00", + data_interval_start="2024-01-01T00:00:00+00:00", + data_interval_end="2024-01-02T00:00:00+00:00", + run_after="2024-01-01T00:00:00+00:00", + run_type="manual", + state="running", + conf={"key": "value"}, + consumed_asset_events=[], + ) + + with mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) as mock_comms: + mock_comms.asend = mock.AsyncMock(return_value=mock_dag_run_result) + + trigger_gen = trigger.run() + running_event = await anext(trigger_gen) + assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING + + success_event = await anext(trigger_gen) + assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + + # Verify the callback was called with context (since AsyncMock accepts context) + call_kwargs = mock_callback.call_args[1] + assert call_kwargs["message"] == TEST_MESSAGE + assert "context" in call_kwargs + context = call_kwargs["context"] + assert context["dag_id"] == TEST_DAG_ID + assert context["run_id"] == TEST_RUN_ID + assert context["deadline"]["id"] == TEST_DEADLINE_ID + assert context["deadline"]["deadline_time"] == TEST_DEADLINE_TIME + assert context["conf"] == {"key": "value"} + + @pytest.mark.asyncio + async def test_run_backward_compat_with_stored_context(self, mock_import_string): + """When kwargs contain 'context' directly (pre-upgrade callbacks), use it as-is.""" + mock_callback = mock.AsyncMock(return_value="done") + mock_import_string.return_value = mock_callback + + legacy_context = {"dag_run": {"dag_id": "old_dag"}, "deadline": {"id": "old-id"}} + trigger = CallbackTrigger( + callback_path=TEST_CALLBACK_PATH, + callback_kwargs={ + "message": TEST_MESSAGE, + "context": legacy_context, + }, + ) + + trigger_gen = trigger.run() + running_event = await anext(trigger_gen) + assert running_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.RUNNING + + success_event = await anext(trigger_gen) + assert success_event.payload[PAYLOAD_STATUS_KEY] == CallbackState.SUCCESS + + # Verify the callback was called with the legacy context directly + call_kwargs = mock_callback.call_args[1] + assert call_kwargs["context"] == legacy_context + assert call_kwargs["message"] == TEST_MESSAGE diff --git a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py index 94d84193192db..8d49925942146 100644 --- a/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py @@ -31,8 +31,10 @@ from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( + DagRunResult, ErrorResponse, GetConnection, + GetDagRun, GetVariable, MaskSecret, ) @@ -69,10 +71,10 @@ class _BundleInfoLike(Protocol): # The set of messages that a callback subprocess can send to the supervisor. -# This is a minimal subset of ToSupervisor: read-only access to Connections -# and Variables, plus MaskSecret for the secrets masker. +# This is a minimal subset of ToSupervisor: read-only access to Connections, +# Variables, and DagRun details, plus MaskSecret for the secrets masker. CallbackToSupervisor = Annotated[ - GetConnection | GetVariable | MaskSecret, + GetConnection | GetDagRun | GetVariable | MaskSecret, Field(discriminator="type"), ] @@ -282,6 +284,9 @@ def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, if isinstance(msg, GetConnection): resp, dump_opts = handle_get_connection(self.client, msg) + elif isinstance(msg, GetDagRun): + dr_resp = self.client.dag_runs.get_detail(msg.dag_id, msg.run_id) + resp = DagRunResult.from_api_response(dr_resp) elif isinstance(msg, GetVariable): resp, dump_opts = handle_get_variable(self.client, msg) elif isinstance(msg, MaskSecret): diff --git a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py index 8cb9fdcc8167a..2f29af75f1e67 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py @@ -21,6 +21,7 @@ import socket from dataclasses import dataclass +from datetime import datetime, timezone from operator import attrgetter from typing import Any from unittest.mock import patch @@ -28,10 +29,12 @@ import pytest import structlog +from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, DagRunType from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback from airflow.sdk.execution_time.comms import ( ConnectionResult, GetConnection, + GetDagRun, GetVariable, MaskSecret, VariableResult, @@ -171,6 +174,22 @@ class RequestCase: ), mask_secret_args=("secret",), ), + RequestCase( + message=GetDagRun(dag_id="test_dag", run_id="test_run_1"), + test_id="get_dag_run", + client_mock=ClientMock( + method_path="dag_runs.get_detail", + args=("test_dag", "test_run_1"), + response=DagRun( + dag_id="test_dag", + run_id="test_run_1", + run_after=datetime(2024, 1, 1, tzinfo=timezone.utc), + run_type=DagRunType.MANUAL, + state=DagRunState.RUNNING, + consumed_asset_events=[], + ), + ), + ), RequestCase( message=GetVariable(key="test_key"), test_id="get_variable",