From b644ce6d2e3149c090706f6b949fb2f1e2ba9978 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 6 May 2026 14:26:43 +0530 Subject: [PATCH 01/10] AIP-103: Adding periodic task state garbage collection and retention support --- airflow-core/src/airflow/cli/cli_config.py | 18 +++++ .../cli/commands/state_store_command.py | 57 ++++++++++++++ .../src/airflow/config_templates/config.yml | 17 +++++ .../src/airflow/jobs/scheduler_job_runner.py | 36 ++++++++- ...0_add_task_state_and_asset_state_tables.py | 5 ++ airflow-core/src/airflow/models/task_state.py | 6 ++ airflow-core/src/airflow/state/metastore.py | 75 ++++++++++++++++++- .../cli/commands/test_state_store_command.py | 48 ++++++++++++ .../tests/unit/state/test_metastore.py | 73 ++++++++++++++++++ .../src/airflow_shared/state/__init__.py | 11 +++ 10 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 airflow-core/src/airflow/cli/commands/state_store_command.py create mode 100644 airflow-core/tests/unit/cli/commands/test_state_store_command.py diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index a09851e5e24ff..1bccc7df09988 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -1519,6 +1519,19 @@ class GroupCommand(NamedTuple): args=(ARG_OUTPUT, ARG_VERBOSE), ), ) +STATE_STORE_COMMANDS = ( + ActionCommand( + name="cleanup", + help="Remove expired task state rows via the configured state backend", + description=( + "Reads [state_store] default_retention_days from config and deletes task_state rows " + "older than the configured threshold. Use --dry-run to preview without deleting." 
+ ), + func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup"), + args=(ARG_DB_DRY_RUN, ARG_VERBOSE), + ), +) + DB_COMMANDS = ( ActionCommand( name="check-migrations", @@ -2102,6 +2115,11 @@ class GroupCommand(NamedTuple): help="Display providers", subcommands=PROVIDERS_COMMANDS, ), + GroupCommand( + name="state-store", + help="Manage task and asset state storage", + subcommands=STATE_STORE_COMMANDS, + ), ActionCommand( name="rotate-fernet-key", func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"), diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py new file mode 100644 index 0000000000000..9080998eed330 --- /dev/null +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import logging + +log = logging.getLogger(__name__) + + +def cleanup(args) -> None: + """Remove expired task state rows via the configured state backend.""" + from airflow.state import get_state_backend + from airflow.state.metastore import MetastoreStateBackend + + backend = get_state_backend() + + if args.dry_run: + if isinstance(backend, MetastoreStateBackend): + summary = backend._dry_run_summary() + stale, expired = summary["stale"], summary["expired"] + total = len(stale) + len(expired) + if not total: + print("Nothing to delete.") + return + print(f"Would delete {total} task state row(s):\n") + if stale: + print(f" Older than retention period ({len(stale)}):") + for dag_id, run_id, task_id, map_index, key in stale: + print( + f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" + ) + if expired: + print(f"\n Per-key expiry reached ({len(expired)}):") + for dag_id, run_id, task_id, map_index, key in expired: + print( + f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" + ) + else: + print("Custom backend configured — cannot preview rows.") + return + + log.info("Running state store cleanup") + backend.cleanup() diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index c37989b9b7486..81fb5e9b2beb0 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3014,6 +3014,23 @@ state_store: type: string example: "mypackage.state.CustomStateBackend" default: "airflow.state.metastore.MetastoreStateBackend" + default_retention_days: + description: | + Number of days to retain task_state rows after their last update. + Rows older than this are removed by the scheduler's periodic cleanup. + This config does not affect asset_state rows. + Set to 0 to disable time-based cleanup entirely. 
+ version_added: 3.3.0 + type: integer + example: "7" + default: "30" + clear_on_success: + description: | + If True, task state is automatically cleared when a task instance succeeds. + version_added: 3.3.0 + type: boolean + example: "True" + default: "False" profiling: description: | diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 80897213c18b5..e0db564ee2f89 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -33,7 +33,22 @@ from itertools import groupby from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update +from sqlalchemy import ( + CTE, + and_, + case, + delete, + delete as _delete, + exists, + func, + inspect, + or_, + select, + select as _select, + text, + tuple_, + update, +) from sqlalchemy.exc import DBAPIError, OperationalError from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression @@ -58,10 +73,12 @@ from airflow.models import Deadline, Log from airflow.models.asset import ( AssetActive, + AssetActive as _AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel, + AssetModel as _AssetModel, AssetPartitionDagRun, AssetWatcherModel, DagScheduleAssetAliasReference, @@ -70,6 +87,7 @@ TaskInletAssetReference, TaskOutletAssetReference, ) +from airflow.models.asset_state import AssetStateModel from airflow.models.backfill import Backfill, BackfillDagRun from airflow.models.callback import Callback, CallbackType, ExecutorCallback from airflow.models.dag import DagModel @@ -3080,6 +3098,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None: self._orphan_unreferenced_assets(orphan_query, session=session) self._activate_referenced_assets(activate_query, session=session) + 
self._cleanup_orphaned_asset_state(session=session) @staticmethod def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None: @@ -3188,6 +3207,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: session.add(warning) existing_warned_dag_ids.add(warning.dag_id) + @staticmethod + def _cleanup_orphaned_asset_state(*, session: Session) -> None: + """ + Delete asset_state rows for assets no longer active in any DAG. + + When _orphan_unreferenced_assets removes an asset from asset_active, its + asset_state rows become unreachable — no task can write to them anymore. + This runs in the same pass as asset orphanage to keep the table clean. + """ + active_asset_ids = _select(_AssetModel.id).join( + _AssetActive, + (_AssetActive.name == _AssetModel.name) & (_AssetActive.uri == _AssetModel.uri), + ) + session.execute(_delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) + def _executor_to_workloads( self, workloads: Iterable[SchedulerWorkload], diff --git a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py index 7f852d05c6ca6..d40f9c90bc781 100644 --- a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py +++ b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py @@ -65,6 +65,11 @@ def upgrade(): sa.Column("run_id", StringID(), nullable=False), sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False), sa.Column("updated_at", UtcDateTime(), nullable=False), + # Optional early-expiry override. When set, GC deletes this row when expires_at < now() + # even if updated_at is recent. NULL means no early expiry — the row is still cleaned + # up by the global updated_at + default_retention_days check. 
Populated via + # task_state.set(retention_days=N) for keys that should expire sooner than the default. + sa.Column("expires_at", UtcDateTime(), nullable=True), sa.ForeignKeyConstraint( ["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE" ), diff --git a/airflow-core/src/airflow/models/task_state.py b/airflow-core/src/airflow/models/task_state.py index dbc17e3b06950..cca7ce99de23a 100644 --- a/airflow-core/src/airflow/models/task_state.py +++ b/airflow-core/src/airflow/models/task_state.py @@ -49,6 +49,12 @@ class TaskStateModel(Base): value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False) updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + # Optional override for early expiry. When set, garbage collection deletes this row when + # expires_at < now(), even if updated_at is recent. NULL means no early expiry — + # the row is still cleaned up by the global `updated_at + default_retention_days` check. + # Populated via task_state.set(retention_days=N) for keys that should expire differently + # than the deployment wide default. + expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) __table_args__ = ( PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index 3382dad81fc65..84a687e890488 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -17,17 +17,20 @@ # under the License. 
from __future__ import annotations +from datetime import timedelta from typing import TYPE_CHECKING +import structlog from sqlalchemy import delete, select from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope from airflow._shared.timezones import timezone +from airflow.configuration import conf from airflow.models.asset_state import AssetStateModel from airflow.models.dagrun import DagRun from airflow.models.task_state import TaskStateModel from airflow.typing_compat import assert_never -from airflow.utils.session import NEW_SESSION, create_session_async, provide_session +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session from airflow.utils.sqlalchemy import get_dialect_name if TYPE_CHECKING: @@ -38,6 +41,9 @@ from sqlalchemy.orm import Session +log = structlog.get_logger(__name__) + + def _build_upsert_stmt( dialect: str | None, model: type, @@ -252,6 +258,73 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None: ) ) + def cleanup(self) -> None: + """ + Remove expired task state rows. + + Reads ``[state_store] default_retention_days`` from config for the time-based threshold. + Set to 0 to disable time-based cleanup (expires_at cleanup still runs). + + Two passes: + a. Rows where updated_at < now() - default_retention_days (global retention) + b. Rows where expires_at < now() (per-key early expiry set by the operator) + + Asset state orphan cleanup is handled separately by the scheduler's + _cleanup_orphaned_asset_state(), which runs alongside asset deregistration. 
+ """ + retention_days = conf.getint("state_store", "default_retention_days") + now = timezone.utcnow() + older_than = now - timedelta(days=retention_days) if retention_days > 0 else None + with create_session() as session: + if older_than: + result = session.execute( # type: ignore[assignment] + delete(TaskStateModel) + .where(TaskStateModel.updated_at < older_than) + .execution_options(synchronize_session="fetch") + ) + log.info( + "Deleted stale task_state rows", + rows_deleted=getattr(result, "rowcount", None), + older_than=older_than, + ) + result = session.execute( # type: ignore[assignment] + delete(TaskStateModel) + .where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now) + .execution_options(synchronize_session="fetch") + ) + log.info("Deleted expired task_state rows", rows_deleted=getattr(result, "rowcount", None)) + + def _dry_run_summary(self) -> dict[str, list]: + """ + Return rows that would be deleted by cleanup(), without deleting anything. + + Returns a dict with keys 'stale' and 'expired', each containing a list of + (dag_id, run_id, task_id, map_index, key) tuples. 
+ """ + retention_days = conf.getint("state_store", "default_retention_days") + now = timezone.utcnow() + older_than = now - timedelta(days=retention_days) if retention_days > 0 else None + + cols = ( + TaskStateModel.dag_id, + TaskStateModel.run_id, + TaskStateModel.task_id, + TaskStateModel.map_index, + TaskStateModel.key, + ) + + with create_session() as session: + stale = ( + session.execute(select(*cols).where(TaskStateModel.updated_at < older_than)).all() + if older_than + else [] + ) + expired = session.execute( + select(*cols).where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now) + ).all() + + return {"stale": list(stale), "expired": list(expired)} + async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: row = await session.scalar( select(TaskStateModel).where( diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py new file mode 100644 index 0000000000000..3086cb4b141c6 --- /dev/null +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from argparse import Namespace +from unittest import mock +from unittest.mock import MagicMock, patch + +from airflow.cli.commands.state_store_command import cleanup +from airflow.state.metastore import MetastoreStateBackend + + +class TestStateStoreCleanupCommand: + def test_cleanup_calls_backend(self): + args = Namespace(dry_run=False, verbose=False) + with mock.patch("airflow.state.get_state_backend") as mock_get_backend: + mock_backend = MagicMock() + mock_get_backend.return_value = mock_backend + + cleanup(args) + + mock_backend.cleanup.assert_called_once_with() + + def test_dry_run_does_not_call_backend(self, capsys): + args = Namespace(dry_run=True, verbose=False) + backend = MetastoreStateBackend() + with ( + mock.patch("airflow.state.get_state_backend", return_value=backend), + patch.object(backend, "_dry_run_summary", return_value={"stale": [], "expired": []}), + ): + cleanup(args) + + captured = capsys.readouterr() + assert "Nothing to delete" in captured.out diff --git a/airflow-core/tests/unit/state/test_metastore.py b/airflow-core/tests/unit/state/test_metastore.py index 98993d7133c41..b536803dcfc62 100644 --- a/airflow-core/tests/unit/state/test_metastore.py +++ b/airflow-core/tests/unit/state/test_metastore.py @@ -17,13 +17,16 @@ # under the License. 
from __future__ import annotations +from datetime import timedelta from typing import TYPE_CHECKING import pytest from sqlalchemy import select from airflow._shared.timezones import timezone +from airflow.configuration import conf from airflow.models.asset import AssetModel +from airflow.models.asset_state import AssetStateModel from airflow.models.dagrun import DagRun, DagRunType from airflow.models.task_state import TaskStateModel from airflow.state import AssetScope, TaskScope, resolve_state_backend @@ -234,6 +237,52 @@ def test_clear_with_all_map_indices_flag_wipes_wide( assert backend.get(scope0, "job_id", session=session) is None assert backend.get(scope1, "job_id", session=session) is None + def test_cleanup_removes_expired_rows( + self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun + ): + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + backend.set(scope, "old_key", "old_value", session=session) + session.flush() + + old_row = session.scalar( + select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "old_key") + ) + assert old_row is not None + old_row.updated_at = timezone.utcnow() - timedelta(days=40) + session.flush() + + backend.set(scope, "new_key", "new_value", session=session) + session.flush() + session.commit() + + backend.cleanup() + + session.expire_all() + assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "old_key")) is None + assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "new_key")) is not None + + def test_cleanup_removes_expires_at_rows( + self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun + ): + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + backend.set(scope, "short_lived", "value", session=session) + session.flush() + + row = session.scalar( + select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "short_lived") + ) + assert row is not None + row.expires_at 
= timezone.utcnow() - timedelta(hours=1) + session.flush() + session.commit() + + backend.cleanup() + + session.expire_all() + + # cleaned up via expires_at, even though updated_at is recent + assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "short_lived")) is None + class TestMetastoreStateBackendAssetScope: def test_get_returns_none_for_missing_key( @@ -306,6 +355,19 @@ def test_different_assets_are_isolated( assert backend.get(scope2, "watermark", session=session) is None + def test_cleanup_does_not_touch_asset_state( + self, session: Session, backend: MetastoreStateBackend, asset: AssetModel + ): + scope = AssetScope(asset_id=asset.id) + backend.set(scope, "watermark", "2026-01-01", session=session) + session.flush() + session.commit() + + backend.cleanup() + + session.expire_all() + assert session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id == asset.id)) is not None + @pytest.mark.asyncio(loop_scope="class") class TestMetastoreStateBackendAsync: @@ -380,6 +442,17 @@ async def test_aset_task_raises_for_missing_dag_run(self, backend: MetastoreStat await backend.aset(scope, "job_id", "app_async") +class TestStateStoreConfig: + def test_defaults(self): + assert conf.getint("state_store", "default_retention_days") == 30 + assert conf.getboolean("state_store", "clear_on_success") is False + + @conf_vars({("state_store", "default_retention_days"): "7", ("state_store", "clear_on_success"): "True"}) + def test_overrides(self): + assert conf.getint("state_store", "default_retention_days") == 7 + assert conf.getboolean("state_store", "clear_on_success") is True + + class TestResolveStateBackend: @conf_vars({("state_store", "backend"): "airflow.state.metastore.MetastoreStateBackend"}) def test_resolve_returns_configured_backend(self): diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 463d9f378f315..1e03a381957d5 100644 --- 
a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -122,3 +122,14 @@ async def aclear(self, scope: StateScope, *, all_map_indices: bool = False) -> N scope are cleared. Pass ``all_map_indices=True`` to wipe state across every mapped instance of the task. For ``AssetScope`` the flag has no effect. """ + + def cleanup(self) -> None: + """ + Remove expired and orphaned state records. + + This is a no-op by default. Custom backends override this to implement their own + retention policy. The backend is responsible for reading any relevant config (e.g. + ``[state_store] default_retention_days``) and deciding what to delete. + Airflow does not call this from any standard job — the scheduler triggers it via + ``call_regular_interval`` for the default backend. + """ From cdc423703e23306d5bfd6df6cf5f452d41e0d4ac Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 7 May 2026 18:23:38 +0530 Subject: [PATCH 02/10] comments from jason --- .../cli/commands/state_store_command.py | 4 ++ .../src/airflow/config_templates/config.yml | 11 ++-- ...0_add_task_state_and_asset_state_tables.py | 4 ++ airflow-core/src/airflow/models/task_state.py | 2 + airflow-core/src/airflow/state/metastore.py | 58 +++++++++++-------- .../tests/unit/state/test_metastore.py | 33 ++++++++++- 6 files changed, 81 insertions(+), 31 deletions(-) diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index 9080998eed330..c15c1d2277bbc 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -20,6 +20,10 @@ log = logging.getLogger(__name__) +# Other state operations (list, get, delete per key) will be added here once the +# Core API endpoints (PR 6) land. For now, inspection is available via the REST +# API and the Task Instance detail panel in the UI. 
+ def cleanup(args) -> None: """Remove expired task state rows via the configured state backend.""" diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 81fb5e9b2beb0..69a33d90eeed8 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3024,13 +3024,14 @@ state_store: type: integer example: "7" default: "30" - clear_on_success: + state_cleanup_batch_size: description: | - If True, task state is automatically cleared when a task instance succeeds. + Number of rows deleted per batch during cleanup. Defaults to 0 (no batching). + Tune this on deployments with large task_state tables to improve performance per transaction. version_added: 3.3.0 - type: boolean - example: "True" - default: "False" + type: integer + example: "10000" + default: "0" profiling: description: | diff --git a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py index d40f9c90bc781..30297edca7423 100644 --- a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py +++ b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py @@ -79,11 +79,15 @@ def upgrade(): batch_op.create_index( "idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False ) + batch_op.create_index("idx_task_state_updated_at", ["updated_at"], unique=False) + batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False) def downgrade(): """Unapply add task_state and asset_state tables.""" with op.batch_alter_table("task_state", schema=None) as batch_op: + batch_op.drop_index("idx_task_state_expires_at") + batch_op.drop_index("idx_task_state_updated_at") batch_op.drop_index("idx_task_state_lookup") 
op.drop_table("task_state") diff --git a/airflow-core/src/airflow/models/task_state.py b/airflow-core/src/airflow/models/task_state.py index cca7ce99de23a..4dca421df3344 100644 --- a/airflow-core/src/airflow/models/task_state.py +++ b/airflow-core/src/airflow/models/task_state.py @@ -65,4 +65,6 @@ class TaskStateModel(Base): ondelete="CASCADE", ), Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"), + Index("idx_task_state_updated_at", "updated_at"), + Index("idx_task_state_expires_at", "expires_at"), ) diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index 84a687e890488..54c356f36dda7 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -22,6 +22,7 @@ import structlog from sqlalchemy import delete, select +from sqlalchemy.sql.expression import tuple_ from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope from airflow._shared.timezones import timezone @@ -262,37 +263,48 @@ def cleanup(self) -> None: """ Remove expired task state rows. - Reads ``[state_store] default_retention_days`` from config for the time-based threshold. - Set to 0 to disable time-based cleanup (expires_at cleanup still runs). + Reads ``[state_store] default_retention_days`` and ``[state_store] state_cleanup_batch_size`` + from config. Each pass runs in its own transaction so partial progress is committed even if a + later pass fails. Each pass is batched to avoid long-running locks on the table. Two passes: a. Rows where updated_at < now() - default_retention_days (global retention) b. Rows where expires_at < now() (per-key early expiry set by the operator) - - Asset state orphan cleanup is handled separately by the scheduler's - _cleanup_orphaned_asset_state(), which runs alongside asset deregistration. 
""" retention_days = conf.getint("state_store", "default_retention_days") + batch_size = conf.getint("state_store", "state_cleanup_batch_size") now = timezone.utcnow() older_than = now - timedelta(days=retention_days) if retention_days > 0 else None - with create_session() as session: - if older_than: - result = session.execute( # type: ignore[assignment] - delete(TaskStateModel) - .where(TaskStateModel.updated_at < older_than) - .execution_options(synchronize_session="fetch") - ) - log.info( - "Deleted stale task_state rows", - rows_deleted=getattr(result, "rowcount", None), - older_than=older_than, - ) - result = session.execute( # type: ignore[assignment] - delete(TaskStateModel) - .where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now) - .execution_options(synchronize_session="fetch") - ) - log.info("Deleted expired task_state rows", rows_deleted=getattr(result, "rowcount", None)) + + pk_cols = ( + TaskStateModel.dag_run_id, + TaskStateModel.task_id, + TaskStateModel.map_index, + TaskStateModel.key, + ) + + def _delete_batched(where_clause) -> int: + total = 0 + while True: + with create_session() as session: + pk_query = select(*pk_cols).where(where_clause) + if batch_size > 0: + pk_query = pk_query.limit(batch_size) + ids = session.execute(pk_query).all() + if not ids: + break + session.execute(delete(TaskStateModel).where(tuple_(*pk_cols).in_(ids))) + total += len(ids) + if batch_size <= 0 or len(ids) < batch_size: + break + return total + + if older_than: + deleted = _delete_batched(TaskStateModel.updated_at < older_than) + log.info("Deleted stale task_state rows", rows_deleted=deleted, older_than=older_than) + + deleted = _delete_batched((TaskStateModel.expires_at.isnot(None)) & (TaskStateModel.expires_at < now)) + log.info("Deleted expired task_state rows", rows_deleted=deleted) def _dry_run_summary(self) -> dict[str, list]: """ diff --git a/airflow-core/tests/unit/state/test_metastore.py 
b/airflow-core/tests/unit/state/test_metastore.py index b536803dcfc62..72ee0b3456460 100644 --- a/airflow-core/tests/unit/state/test_metastore.py +++ b/airflow-core/tests/unit/state/test_metastore.py @@ -283,6 +283,31 @@ def test_cleanup_removes_expires_at_rows( # cleaned up via expires_at, even though updated_at is recent assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "short_lived")) is None + @conf_vars({("state_store", "state_cleanup_batch_size"): "2"}) + def test_cleanup_batches_deletes(self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun): + from unittest.mock import patch + + import airflow.state.metastore as metastore_mod + + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + for key in ("k1", "k2", "k3", "k4", "k5"): + backend.set(scope, key, "v", session=session) + session.flush() + + session.execute( + TaskStateModel.__table__.update().values(updated_at=timezone.utcnow() - timedelta(days=40)) + ) + session.commit() + + with patch.object(metastore_mod, "create_session", wraps=metastore_mod.create_session) as mock_cs: + backend.cleanup() + # 5 rows, batch size = 2, so 3 batches and 1 check for expired_by row + assert mock_cs.call_count == 4 + + session.expire_all() + remaining = session.scalars(select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID)).all() + assert remaining == [] + class TestMetastoreStateBackendAssetScope: def test_get_returns_none_for_missing_key( @@ -445,12 +470,14 @@ async def test_aset_task_raises_for_missing_dag_run(self, backend: MetastoreStat class TestStateStoreConfig: def test_defaults(self): assert conf.getint("state_store", "default_retention_days") == 30 - assert conf.getboolean("state_store", "clear_on_success") is False + assert conf.getint("state_store", "state_cleanup_batch_size") == 0 - @conf_vars({("state_store", "default_retention_days"): "7", ("state_store", "clear_on_success"): "True"}) + @conf_vars( + {("state_store", "default_retention_days"): "7", 
("state_store", "state_cleanup_batch_size"): "50"} + ) def test_overrides(self): assert conf.getint("state_store", "default_retention_days") == 7 - assert conf.getboolean("state_store", "clear_on_success") is True + assert conf.getint("state_store", "state_cleanup_batch_size") == 50 class TestResolveStateBackend: From df379c50e66739981b03e61fee7afcc803946a38 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 8 May 2026 12:06:34 +0530 Subject: [PATCH 03/10] handling comments from ash --- airflow-core/src/airflow/cli/cli_config.py | 2 +- .../cli/commands/state_store_command.py | 27 ++---- ...0_add_task_state_and_asset_state_tables.py | 10 +-- airflow-core/src/airflow/models/task_state.py | 15 ++-- airflow-core/src/airflow/state/metastore.py | 88 ++++++++----------- .../cli/commands/test_state_store_command.py | 2 +- .../tests/unit/state/test_metastore.py | 61 ++++++++++--- 7 files changed, 105 insertions(+), 100 deletions(-) diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index 1bccc7df09988..d4d4611243bb0 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -1522,7 +1522,7 @@ class GroupCommand(NamedTuple): STATE_STORE_COMMANDS = ( ActionCommand( name="cleanup", - help="Remove expired task state rows via the configured state backend", + help="Remove expired stored state via the configured state backend", description=( "Reads [state_store] default_retention_days from config and deletes task_state rows " "older than the configured threshold. Use --dry-run to preview without deleting." 
diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index c15c1d2277bbc..6aa5a83c200cd 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -20,9 +20,7 @@ log = logging.getLogger(__name__) -# Other state operations (list, get, delete per key) will be added here once the -# Core API endpoints (PR 6) land. For now, inspection is available via the REST -# API and the Task Instance detail panel in the UI. +# Other state operations (list, get, delete per key) will be added here in the future. def cleanup(args) -> None: @@ -35,24 +33,15 @@ def cleanup(args) -> None: if args.dry_run: if isinstance(backend, MetastoreStateBackend): summary = backend._dry_run_summary() - stale, expired = summary["stale"], summary["expired"] - total = len(stale) + len(expired) - if not total: + expired = summary["expired"] + if not expired: print("Nothing to delete.") return - print(f"Would delete {total} task state row(s):\n") - if stale: - print(f" Older than retention period ({len(stale)}):") - for dag_id, run_id, task_id, map_index, key in stale: - print( - f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" - ) - if expired: - print(f"\n Per-key expiry reached ({len(expired)}):") - for dag_id, run_id, task_id, map_index, key in expired: - print( - f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" - ) + print(f"Would delete {len(expired)} task state row(s):\n") + for dag_id, run_id, task_id, map_index, key in expired: + print( + f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" + ) else: print("Custom backend configured — cannot preview rows.") return diff --git a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py 
b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py index 30297edca7423..e64f80a05b119 100644 --- a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py +++ b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py @@ -57,6 +57,7 @@ def upgrade(): ) op.create_table( "task_state", + sa.Column("id", sa.Integer(), nullable=False, autoincrement=True), sa.Column("dag_run_id", sa.Integer(), nullable=False), sa.Column("task_id", StringID(), nullable=False), sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False), @@ -65,21 +66,17 @@ def upgrade(): sa.Column("run_id", StringID(), nullable=False), sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False), sa.Column("updated_at", UtcDateTime(), nullable=False), - # Optional early-expiry override. When set, GC deletes this row when expires_at < now() - # even if updated_at is recent. NULL means no early expiry — the row is still cleaned - # up by the global updated_at + default_retention_days check. Populated via - # task_state.set(retention_days=N) for keys that should expire sooner than the default. 
sa.Column("expires_at", UtcDateTime(), nullable=True), sa.ForeignKeyConstraint( ["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE" ), - sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), + sa.PrimaryKeyConstraint("id", name="task_state_pkey"), + sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"), ) with op.batch_alter_table("task_state", schema=None) as batch_op: batch_op.create_index( "idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False ) - batch_op.create_index("idx_task_state_updated_at", ["updated_at"], unique=False) batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False) @@ -87,7 +84,6 @@ def downgrade(): """Unapply add task_state and asset_state tables.""" with op.batch_alter_table("task_state", schema=None) as batch_op: batch_op.drop_index("idx_task_state_expires_at") - batch_op.drop_index("idx_task_state_updated_at") batch_op.drop_index("idx_task_state_lookup") op.drop_table("task_state") diff --git a/airflow-core/src/airflow/models/task_state.py b/airflow-core/src/airflow/models/task_state.py index 4dca421df3344..72a7624eddd6e 100644 --- a/airflow-core/src/airflow/models/task_state.py +++ b/airflow-core/src/airflow/models/task_state.py @@ -19,7 +19,7 @@ from datetime import datetime -from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text +from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import Mapped, mapped_column @@ -39,10 +39,12 @@ class TaskStateModel(Base): __tablename__ = "task_state" - dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True) - task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True) - map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, 
server_default="-1") - key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False) + task_id: Mapped[str] = mapped_column(StringID(), nullable=False) + map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default="-1") + key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False) dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) run_id: Mapped[str] = mapped_column(StringID(), nullable=False) @@ -57,7 +59,7 @@ class TaskStateModel(Base): expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) __table_args__ = ( - PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), + UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"), ForeignKeyConstraint( ["dag_run_id"], ["dag_run.id"], @@ -65,6 +67,5 @@ class TaskStateModel(Base): ondelete="CASCADE", ), Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"), - Index("idx_task_state_updated_at", "updated_at"), Index("idx_task_state_expires_at", "expires_at"), ) diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index 54c356f36dda7..cdf595d2be634 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -17,12 +17,11 @@ # under the License. 
from __future__ import annotations -from datetime import timedelta +from datetime import datetime, timedelta from typing import TYPE_CHECKING import structlog from sqlalchemy import delete, select -from sqlalchemy.sql.expression import tuple_ from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope from airflow._shared.timezones import timezone @@ -45,6 +44,18 @@ log = structlog.get_logger(__name__) +def _compute_expires_at(now: datetime) -> datetime | None: + """ + Return the expiry timestamp for a new task state row based on config. + + Returns None if default_retention_days is 0 (never expires). + """ + retention_days = conf.getint("state_store", "default_retention_days") + if retention_days <= 0: + return None + return now + timedelta(days=retention_days) + + def _build_upsert_stmt( dialect: str | None, model: type, @@ -183,6 +194,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = _compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -192,13 +204,14 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se key=key, value=value, updated_at=now, + expires_at=expires_at, ) stmt = _build_upsert_stmt( get_dialect_name(session), TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), + dict(value=value, updated_at=now, expires_at=expires_at), ) session.execute(stmt) @@ -263,60 +276,36 @@ def cleanup(self) -> None: """ Remove expired task state rows. - Reads ``[state_store] default_retention_days`` and ``[state_store] state_cleanup_batch_size`` - from config. Each pass runs in its own transaction so partial progress is committed even if a - later pass fails. Each pass is batched to avoid long-running locks on the table. - - Two passes: - a. 
Rows where updated_at < now() - default_retention_days (global retention) - b. Rows where expires_at < now() (per-key early expiry set by the operator) + ``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single + ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0) + are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``. """ - retention_days = conf.getint("state_store", "default_retention_days") batch_size = conf.getint("state_store", "state_cleanup_batch_size") now = timezone.utcnow() - older_than = now - timedelta(days=retention_days) if retention_days > 0 else None - - pk_cols = ( - TaskStateModel.dag_run_id, - TaskStateModel.task_id, - TaskStateModel.map_index, - TaskStateModel.key, - ) def _delete_batched(where_clause) -> int: total = 0 - while True: - with create_session() as session: - pk_query = select(*pk_cols).where(where_clause) + with create_session() as session: + while True: + id_query = select(TaskStateModel.id).where(where_clause) if batch_size > 0: - pk_query = pk_query.limit(batch_size) - ids = session.execute(pk_query).all() + id_query = id_query.limit(batch_size) + ids = session.scalars(id_query).all() if not ids: break - session.execute(delete(TaskStateModel).where(tuple_(*pk_cols).in_(ids))) + session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids))) + session.commit() total += len(ids) - if batch_size <= 0 or len(ids) < batch_size: - break + if batch_size <= 0 or len(ids) < batch_size: + break return total - if older_than: - deleted = _delete_batched(TaskStateModel.updated_at < older_than) - log.info("Deleted stale task_state rows", rows_deleted=deleted, older_than=older_than) - - deleted = _delete_batched((TaskStateModel.expires_at.isnot(None)) & (TaskStateModel.expires_at < now)) + deleted = _delete_batched(TaskStateModel.expires_at < now) log.info("Deleted expired task_state rows", rows_deleted=deleted) def 
_dry_run_summary(self) -> dict[str, list]: - """ - Return rows that would be deleted by cleanup(), without deleting anything. - - Returns a dict with keys 'stale' and 'expired', each containing a list of - (dag_id, run_id, task_id, map_index, key) tuples. - """ - retention_days = conf.getint("state_store", "default_retention_days") + """Return rows that would be deleted by cleanup() without deleting anything.""" now = timezone.utcnow() - older_than = now - timedelta(days=retention_days) if retention_days > 0 else None - cols = ( TaskStateModel.dag_id, TaskStateModel.run_id, @@ -324,18 +313,9 @@ def _dry_run_summary(self) -> dict[str, list]: TaskStateModel.map_index, TaskStateModel.key, ) - with create_session() as session: - stale = ( - session.execute(select(*cols).where(TaskStateModel.updated_at < older_than)).all() - if older_than - else [] - ) - expired = session.execute( - select(*cols).where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now) - ).all() - - return {"stale": list(stale), "expired": list(expired)} + expired = session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all() + return {"expired": list(expired)} async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: row = await session.scalar( @@ -361,6 +341,7 @@ async def _aset_task_state( if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = _compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -370,6 +351,7 @@ async def _aset_task_state( key=key, value=value, updated_at=now, + expires_at=expires_at, ) # get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to stmt = _build_upsert_stmt( @@ -377,7 +359,7 @@ async def _aset_task_state( TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), 
+ dict(value=value, updated_at=now, expires_at=expires_at), ) await session.execute(stmt) diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py index 3086cb4b141c6..a6ad669181156 100644 --- a/airflow-core/tests/unit/cli/commands/test_state_store_command.py +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -40,7 +40,7 @@ def test_dry_run_does_not_call_backend(self, capsys): backend = MetastoreStateBackend() with ( mock.patch("airflow.state.get_state_backend", return_value=backend), - patch.object(backend, "_dry_run_summary", return_value={"stale": [], "expired": []}), + patch.object(backend, "_dry_run_summary", return_value={"expired": []}), ): cleanup(args) diff --git a/airflow-core/tests/unit/state/test_metastore.py b/airflow-core/tests/unit/state/test_metastore.py index 72ee0b3456460..2407f21d51bc4 100644 --- a/airflow-core/tests/unit/state/test_metastore.py +++ b/airflow-core/tests/unit/state/test_metastore.py @@ -17,11 +17,13 @@ # under the License. 
from __future__ import annotations +from contextlib import contextmanager from datetime import timedelta from typing import TYPE_CHECKING +from unittest.mock import patch import pytest -from sqlalchemy import select +from sqlalchemy import Delete, select from airflow._shared.timezones import timezone from airflow.configuration import conf @@ -237,21 +239,32 @@ def test_clear_with_all_map_indices_flag_wipes_wide( assert backend.get(scope0, "job_id", session=session) is None assert backend.get(scope1, "job_id", session=session) is None + def test_set_populates_expires_at( + self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun + ): + """set() always populates expires_at so cleanup has a single pass.""" + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + backend.set(scope, "job_id", "app_1234", session=session) + session.flush() + + row = session.scalar(select(TaskStateModel).where(TaskStateModel.key == "job_id")) + assert row is not None + assert row.expires_at is not None + def test_cleanup_removes_expired_rows( self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun ): scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) backend.set(scope, "old_key", "old_value", session=session) + backend.set(scope, "new_key", "new_value", session=session) session.flush() + # Backdate expires_at on old_key to simulate it having expired old_row = session.scalar( select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "old_key") ) assert old_row is not None - old_row.updated_at = timezone.utcnow() - timedelta(days=40) - session.flush() - - backend.set(scope, "new_key", "new_value", session=session) + old_row.expires_at = timezone.utcnow() - timedelta(hours=1) session.flush() session.commit() @@ -285,8 +298,17 @@ def test_cleanup_removes_expires_at_rows( @conf_vars({("state_store", "state_cleanup_batch_size"): "2"}) def test_cleanup_batches_deletes(self, session: Session, backend: 
MetastoreStateBackend, dag_run: DagRun):
-        from unittest.mock import patch
+        """cleanup() issues one DELETE per batch, not one DELETE for all rows at once.
+
+        Verifying this is not straightforward because cleanup() creates its own internal session,
+        so we cannot simply inspect it from outside, so what we do is:
+
+        1. Patch `create_session` in the metastore module with a thin wrapper (`tracking_cs`) that
+           yields the real session but replaces `session.execute` with a spy.
+        2. The spy checks whether the statement being executed is a sqla Delete object and
+           records it if so.
+        3. After cleanup() returns, we assert that exactly ceil(row_count / batch_size) deletes ran.
+        """
         import airflow.state.metastore as metastore_mod

         scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
@@ -295,18 +317,33 @@ def test_cleanup_batches_deletes(self, session: Session, backend: MetastoreState
         session.flush()

         session.execute(
-            TaskStateModel.__table__.update().values(updated_at=timezone.utcnow() - timedelta(days=40))
+            TaskStateModel.__table__.update().values(expires_at=timezone.utcnow() - timedelta(hours=1))
         )
         session.commit()

-        with patch.object(metastore_mod, "create_session", wraps=metastore_mod.create_session) as mock_cs:
+        deletes = []
+        original_cs = metastore_mod.create_session
+
+        @contextmanager
+        def tracking_cs(*args, **kwargs):
+            with original_cs(*args, **kwargs) as s:
+                orig_execute = s.execute
+
+                def tracked(stmt, *a, **kw):
+                    if isinstance(stmt, Delete):
+                        deletes.append(stmt)
+                    return orig_execute(stmt, *a, **kw)
+
+                s.execute = tracked
+                yield s
+
+        with patch.object(metastore_mod, "create_session", side_effect=tracking_cs):
             backend.cleanup()
-            # 5 rows, batch size = 2, so 3 batches and 1 check for expired_by row
-            assert mock_cs.call_count == 4

         session.expire_all()
-        remaining = session.scalars(select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID)).all()
-        assert remaining == []
+
+        # batch_size=2, 5 rows -> delete runs 3 times (2+2+1)
+        assert len(deletes) == 3


 class 
TestMetastoreStateBackendAssetScope: From f52ce27fccc392de6e070c5e6f60929f028152a0 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 11 May 2026 13:51:29 +0530 Subject: [PATCH 04/10] comment from TP --- .../src/airflow/jobs/scheduler_job_runner.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index eeabfc1427678..d0792554c273e 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -38,13 +38,11 @@ and_, case, delete, - delete as _delete, exists, func, inspect, or_, select, - select as _select, text, tuple_, update, @@ -73,12 +71,10 @@ from airflow.models import Deadline, Log from airflow.models.asset import ( AssetActive, - AssetActive as _AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel, - AssetModel as _AssetModel, AssetPartitionDagRun, AssetWatcherModel, DagScheduleAssetAliasReference, @@ -3228,11 +3224,11 @@ def _cleanup_orphaned_asset_state(*, session: Session) -> None: asset_state rows become unreachable — no task can write to them anymore. This runs in the same pass as asset orphanage to keep the table clean. 
""" - active_asset_ids = _select(_AssetModel.id).join( - _AssetActive, - (_AssetActive.name == _AssetModel.name) & (_AssetActive.uri == _AssetModel.uri), + active_asset_ids = select(AssetModel.id).join( + AssetActive, + (AssetActive.name == AssetModel.name) & (AssetActive.uri == AssetModel.uri), ) - session.execute(_delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) + session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) def _executor_to_workloads( self, From 7427d04ed2627f1c887e5e5a899bb5e5896f7877 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 11 May 2026 15:05:55 +0530 Subject: [PATCH 05/10] comment from wei --- .../src/airflow/cli/commands/state_store_command.py | 8 ++++---- airflow-core/src/airflow/config_templates/config.yml | 2 +- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 2 +- airflow-core/tests/unit/state/test_metastore.py | 1 + 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index 6aa5a83c200cd..f1475d23a3729 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -18,6 +18,9 @@ import logging +from airflow.state import get_state_backend +from airflow.state.metastore import MetastoreStateBackend + log = logging.getLogger(__name__) # Other state operations (list, get, delete per key) will be added here in the future. 
@@ -25,9 +28,6 @@ def cleanup(args) -> None: """Remove expired task state rows via the configured state backend.""" - from airflow.state import get_state_backend - from airflow.state.metastore import MetastoreStateBackend - backend = get_state_backend() if args.dry_run: @@ -40,7 +40,7 @@ def cleanup(args) -> None: print(f"Would delete {len(expired)} task state row(s):\n") for dag_id, run_id, task_id, map_index, key in expired: print( - f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" + f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" ) else: print("Custom backend configured — cannot preview rows.") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 3f1b80fbf1d17..7ffed3cbdd540 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3027,7 +3027,7 @@ state_store: default: "airflow.state.metastore.MetastoreStateBackend" default_retention_days: description: | - Number of days to retain task_state rows after their last update. + Number of days to retain task state after their last update. Rows older than this are removed by the scheduler's periodic cleanup. This config does not affect asset_state rows. Set to 0 to disable time-based cleanup entirely. diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index d0792554c273e..683c0add69668 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -3218,7 +3218,7 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: @staticmethod def _cleanup_orphaned_asset_state(*, session: Session) -> None: """ - Delete asset_state rows for assets no longer active in any DAG. + Delete asset_state rows for assets no longer active in any Dag. 
When _orphan_unreferenced_assets removes an asset from asset_active, its asset_state rows become unreachable — no task can write to them anymore. diff --git a/airflow-core/tests/unit/state/test_metastore.py b/airflow-core/tests/unit/state/test_metastore.py index 2407f21d51bc4..a311b850b9e2b 100644 --- a/airflow-core/tests/unit/state/test_metastore.py +++ b/airflow-core/tests/unit/state/test_metastore.py @@ -250,6 +250,7 @@ def test_set_populates_expires_at( row = session.scalar(select(TaskStateModel).where(TaskStateModel.key == "job_id")) assert row is not None assert row.expires_at is not None + assert row.expires_at > row.updated_at def test_cleanup_removes_expired_rows( self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun From 151dee5367e6c42895cf2aac36c6c599c84bd2bc Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 11 May 2026 15:16:33 +0530 Subject: [PATCH 06/10] Update airflow-core/src/airflow/state/metastore.py Co-authored-by: Wei Lee --- airflow-core/src/airflow/state/metastore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index cdf595d2be634..46fad8783f725 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -303,7 +303,7 @@ def _delete_batched(where_clause) -> int: deleted = _delete_batched(TaskStateModel.expires_at < now) log.info("Deleted expired task_state rows", rows_deleted=deleted) - def _dry_run_summary(self) -> dict[str, list]: + def _summary_dry_run_(self) -> dict[str, list]: """Return rows that would be deleted by cleanup() without deleting anything.""" now = timezone.utcnow() cols = ( From 66081d0b38652b5e02a197207050bdaacd032fa8 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 12 May 2026 11:26:55 +0530 Subject: [PATCH 07/10] fixing tests and static checks --- .../src/airflow/cli/commands/state_store_command.py | 2 +- 
.../tests/unit/cli/commands/test_state_store_command.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index f1475d23a3729..3851386a87b1f 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -32,7 +32,7 @@ def cleanup(args) -> None: if args.dry_run: if isinstance(backend, MetastoreStateBackend): - summary = backend._dry_run_summary() + summary = backend._summary_dry_run_() expired = summary["expired"] if not expired: print("Nothing to delete.") diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py index a6ad669181156..f74e7df92d43f 100644 --- a/airflow-core/tests/unit/cli/commands/test_state_store_command.py +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -20,9 +20,13 @@ from unittest import mock from unittest.mock import MagicMock, patch +import pytest + from airflow.cli.commands.state_store_command import cleanup from airflow.state.metastore import MetastoreStateBackend +pytestmark = pytest.mark.db_test + class TestStateStoreCleanupCommand: def test_cleanup_calls_backend(self): @@ -40,7 +44,7 @@ def test_dry_run_does_not_call_backend(self, capsys): backend = MetastoreStateBackend() with ( mock.patch("airflow.state.get_state_backend", return_value=backend), - patch.object(backend, "_dry_run_summary", return_value={"expired": []}), + patch.object(backend, "_summary_dry_run_", return_value={"expired": []}), ): cleanup(args) From 58dba887d456890f4ca0a0f51ce88c619fd07d3e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Tue, 12 May 2026 14:21:59 +0530 Subject: [PATCH 08/10] fixing tests --- .../tests/unit/cli/commands/test_state_store_command.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py index f74e7df92d43f..715e4a4e47a21 100644 --- a/airflow-core/tests/unit/cli/commands/test_state_store_command.py +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -31,7 +31,7 @@ class TestStateStoreCleanupCommand: def test_cleanup_calls_backend(self): args = Namespace(dry_run=False, verbose=False) - with mock.patch("airflow.state.get_state_backend") as mock_get_backend: + with mock.patch("airflow.cli.commands.state_store_command.get_state_backend") as mock_get_backend: mock_backend = MagicMock() mock_get_backend.return_value = mock_backend @@ -43,7 +43,7 @@ def test_dry_run_does_not_call_backend(self, capsys): args = Namespace(dry_run=True, verbose=False) backend = MetastoreStateBackend() with ( - mock.patch("airflow.state.get_state_backend", return_value=backend), + mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), patch.object(backend, "_summary_dry_run_", return_value={"expired": []}), ): cleanup(args) From 6b6968e1893f5198edef93e6c3cc6e1978935e8e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 11:37:25 +0530 Subject: [PATCH 09/10] comment from wei --- airflow-core/src/airflow/cli/cli_config.py | 9 +++--- .../cli/commands/state_store_command.py | 31 +++++++++---------- .../cli/commands/test_state_store_command.py | 29 ++++++++++++----- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index d4d4611243bb0..f59f5ddc10aaf 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -1521,13 +1521,14 @@ class GroupCommand(NamedTuple): ) STATE_STORE_COMMANDS = ( ActionCommand( - name="cleanup", - help="Remove expired stored state via the configured state backend", + name="cleanup-task-states", + 
help="Remove expired task state rows (MetastoreStateBackend only)", description=( "Reads [state_store] default_retention_days from config and deletes task_state rows " - "older than the configured threshold. Use --dry-run to preview without deleting." + "older than the configured threshold. Only applies when MetastoreStateBackend is configured; " + "custom backends are skipped. Use --dry-run to preview without deleting." ), - func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup"), + func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"), args=(ARG_DB_DRY_RUN, ARG_VERBOSE), ), ) diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index 3851386a87b1f..62a0203923f0b 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -26,25 +26,24 @@ # Other state operations (list, get, delete per key) will be added here in the future. 
-def cleanup(args) -> None: - """Remove expired task state rows via the configured state backend.""" +def cleanup_task_states(args) -> None: + """Remove expired task state rows (MetastoreStateBackend only).""" backend = get_state_backend() + if not isinstance(backend, MetastoreStateBackend): + print("Custom backend configured — skipping cleanup (not supported).") + return + if args.dry_run: - if isinstance(backend, MetastoreStateBackend): - summary = backend._summary_dry_run_() - expired = summary["expired"] - if not expired: - print("Nothing to delete.") - return - print(f"Would delete {len(expired)} task state row(s):\n") - for dag_id, run_id, task_id, map_index, key in expired: - print( - f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}" - ) - else: - print("Custom backend configured — cannot preview rows.") + summary = backend._summary_dry_run_() + expired = summary["expired"] + if not expired: + print("Nothing to delete.") + return + print(f"Would delete {len(expired)} task state row(s):\n") + for dag_id, run_id, task_id, map_index, key in expired: + print(f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}") return - log.info("Running state store cleanup") + log.info("Running task state cleanup") backend.cleanup() diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py index 715e4a4e47a21..ed877568afa3b 100644 --- a/airflow-core/tests/unit/cli/commands/test_state_store_command.py +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -22,7 +22,7 @@ import pytest -from airflow.cli.commands.state_store_command import cleanup +from airflow.cli.commands.state_store_command import cleanup_task_states from airflow.state.metastore import MetastoreStateBackend pytestmark = pytest.mark.db_test @@ -31,13 +31,14 @@ class TestStateStoreCleanupCommand: def 
test_cleanup_calls_backend(self): args = Namespace(dry_run=False, verbose=False) - with mock.patch("airflow.cli.commands.state_store_command.get_state_backend") as mock_get_backend: - mock_backend = MagicMock() - mock_get_backend.return_value = mock_backend - - cleanup(args) + backend = MetastoreStateBackend() + with ( + mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), + patch.object(backend, "cleanup"), + ): + cleanup_task_states(args) - mock_backend.cleanup.assert_called_once_with() + backend.cleanup.assert_called_once_with() def test_dry_run_does_not_call_backend(self, capsys): args = Namespace(dry_run=True, verbose=False) @@ -46,7 +47,19 @@ def test_dry_run_does_not_call_backend(self, capsys): mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), patch.object(backend, "_summary_dry_run_", return_value={"expired": []}), ): - cleanup(args) + cleanup_task_states(args) captured = capsys.readouterr() assert "Nothing to delete" in captured.out + + def test_custom_backend_is_skipped(self, capsys): + args = Namespace(dry_run=False, verbose=False) + custom_backend = MagicMock(spec=[]) + with mock.patch( + "airflow.cli.commands.state_store_command.get_state_backend", return_value=custom_backend + ): + cleanup_task_states(args) + + captured = capsys.readouterr() + assert "Custom backend configured" in captured.out + assert not hasattr(custom_backend, "cleanup") or not custom_backend.cleanup.called From 77450e580c1b64447e1af84f26dcfb09f099cc4e Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 13 May 2026 17:44:06 +0530 Subject: [PATCH 10/10] last set of comments from wei --- airflow-core/src/airflow/cli/commands/state_store_command.py | 2 +- airflow-core/src/airflow/config_templates/config.yml | 2 +- airflow-core/src/airflow/state/metastore.py | 2 +- .../tests/unit/cli/commands/test_state_store_command.py | 2 +- shared/state/src/airflow_shared/state/__init__.py | 2 -- 5 files 
changed, 4 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py index 62a0203923f0b..52bd095256123 100644 --- a/airflow-core/src/airflow/cli/commands/state_store_command.py +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -35,7 +35,7 @@ def cleanup_task_states(args) -> None: return if args.dry_run: - summary = backend._summary_dry_run_() + summary = backend._summary_dry_run() expired = summary["expired"] if not expired: print("Nothing to delete.") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 7ffed3cbdd540..dfe8162f2c603 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3028,7 +3028,7 @@ state_store: default_retention_days: description: | Number of days to retain task state after their last update. - Rows older than this are removed by the scheduler's periodic cleanup. + Rows older than this are removed when cleanup is triggered. This config does not affect asset_state rows. Set to 0 to disable time-based cleanup entirely. 
version_added: 3.3.0 diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index 46fad8783f725..17b9738734144 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -303,7 +303,7 @@ def _delete_batched(where_clause) -> int: deleted = _delete_batched(TaskStateModel.expires_at < now) log.info("Deleted expired task_state rows", rows_deleted=deleted) - def _summary_dry_run_(self) -> dict[str, list]: + def _summary_dry_run(self) -> dict[str, list]: """Return rows that would be deleted by cleanup() without deleting anything.""" now = timezone.utcnow() cols = ( diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py index ed877568afa3b..e4b44eee13f10 100644 --- a/airflow-core/tests/unit/cli/commands/test_state_store_command.py +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -45,7 +45,7 @@ def test_dry_run_does_not_call_backend(self, capsys): backend = MetastoreStateBackend() with ( mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), - patch.object(backend, "_summary_dry_run_", return_value={"expired": []}), + patch.object(backend, "_summary_dry_run", return_value={"expired": []}), ): cleanup_task_states(args) diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 1e03a381957d5..0560891329b29 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -130,6 +130,4 @@ def cleanup(self) -> None: This is a no-op by default. Custom backends override this to implement their own retention policy. The backend is responsible for reading any relevant config (e.g. ``[state_store] default_retention_days``) and deciding what to delete. 
- Airflow does not call this from any standard job — the scheduler triggers it via - ``call_regular_interval`` for the default backend. """