diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index 4c44ab39d675f..81b9dcf060058 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -1531,6 +1531,20 @@ class GroupCommand(NamedTuple): args=(ARG_VERBOSE,), ), ) +STATE_STORE_COMMANDS = ( + ActionCommand( + name="cleanup-task-states", + help="Remove expired task state rows (MetastoreStateBackend only)", + description=( + "Reads [state_store] default_retention_days from config and deletes task_state rows " + "older than the configured threshold. Only applies when MetastoreStateBackend is configured; " + "custom backends are skipped. Use --dry-run to preview without deleting." + ), + func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"), + args=(ARG_DB_DRY_RUN, ARG_VERBOSE), + ), +) + DB_COMMANDS = ( ActionCommand( name="check-migrations", @@ -2115,6 +2129,11 @@ class GroupCommand(NamedTuple): help="Display providers", subcommands=PROVIDERS_COMMANDS, ), + GroupCommand( + name="state-store", + help="Manage task and asset state storage", + subcommands=STATE_STORE_COMMANDS, + ), ActionCommand( name="rotate-fernet-key", func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"), diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py new file mode 100644 index 0000000000000..52bd095256123 --- /dev/null +++ b/airflow-core/src/airflow/cli/commands/state_store_command.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +from airflow.state import get_state_backend +from airflow.state.metastore import MetastoreStateBackend + +log = logging.getLogger(__name__) + +# Other state operations (list, get, delete per key) will be added here in the future. + + +def cleanup_task_states(args) -> None: + """Remove expired task state rows (MetastoreStateBackend only).""" + backend = get_state_backend() + + if not isinstance(backend, MetastoreStateBackend): + print("Custom backend configured — skipping cleanup (not supported).") + return + + if args.dry_run: + summary = backend._summary_dry_run() + expired = summary["expired"] + if not expired: + print("Nothing to delete.") + return + print(f"Would delete {len(expired)} task state row(s):\n") + for dag_id, run_id, task_id, map_index, key in expired: + print(f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}") + return + + log.info("Running task state cleanup") + backend.cleanup() diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 8d5d6e5fd2611..4b183f9c2b40f 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3025,6 +3025,24 @@ state_store: type: string example: "mypackage.state.CustomStateBackend" default: "airflow.state.metastore.MetastoreStateBackend" + default_retention_days: + description: | + Number of days to retain task state after their last update. 
+ Rows older than this are removed when cleanup is triggered. + This config does not affect asset_state rows. + Set to 0 to disable time-based cleanup entirely. + version_added: 3.3.0 + type: integer + example: "7" + default: "30" + state_cleanup_batch_size: + description: | + Number of rows deleted per batch during cleanup. Defaults to 0 (no batching). + Tune this on deployments with large task_state tables to improve performance per transaction. + version_added: 3.3.0 + type: integer + example: "10000" + default: "0" profiling: description: | diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 1a3f55b7f6f3d..9a650b110c963 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -33,7 +33,20 @@ from itertools import groupby from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update +from sqlalchemy import ( + CTE, + and_, + case, + delete, + exists, + func, + inspect, + or_, + select, + text, + tuple_, + update, +) from sqlalchemy.exc import DBAPIError, OperationalError from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression @@ -70,6 +83,7 @@ TaskInletAssetReference, TaskOutletAssetReference, ) +from airflow.models.asset_state import AssetStateModel from airflow.models.backfill import Backfill, BackfillDagRun from airflow.models.callback import Callback, CallbackType, ExecutorCallback from airflow.models.dag import DagModel @@ -3096,6 +3110,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None: self._orphan_unreferenced_assets(orphan_query, session=session) self._activate_referenced_assets(activate_query, session=session) + self._cleanup_orphaned_asset_state(session=session) @staticmethod def _orphan_unreferenced_assets(assets_query: 
CTE, *, session: Session) -> None: @@ -3204,6 +3219,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]: session.add(warning) existing_warned_dag_ids.add(warning.dag_id) + @staticmethod + def _cleanup_orphaned_asset_state(*, session: Session) -> None: + """ + Delete asset_state rows for assets no longer active in any Dag. + + When _orphan_unreferenced_assets removes an asset from asset_active, its + asset_state rows become unreachable — no task can write to them anymore. + This runs in the same pass as asset orphanage to keep the table clean. + """ + active_asset_ids = select(AssetModel.id).join( + AssetActive, + (AssetActive.name == AssetModel.name) & (AssetActive.uri == AssetModel.uri), + ) + session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) + def _executor_to_workloads( self, workloads: Iterable[SchedulerWorkload], diff --git a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py index 7f852d05c6ca6..e64f80a05b119 100644 --- a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py +++ b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py @@ -57,6 +57,7 @@ def upgrade(): ) op.create_table( "task_state", + sa.Column("id", sa.Integer(), nullable=False, autoincrement=True), sa.Column("dag_run_id", sa.Integer(), nullable=False), sa.Column("task_id", StringID(), nullable=False), sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False), @@ -65,20 +66,24 @@ def upgrade(): sa.Column("run_id", StringID(), nullable=False), sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False), sa.Column("updated_at", UtcDateTime(), nullable=False), + sa.Column("expires_at", UtcDateTime(), nullable=True), sa.ForeignKeyConstraint( 
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE" ), - sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), + sa.PrimaryKeyConstraint("id", name="task_state_pkey"), + sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"), ) with op.batch_alter_table("task_state", schema=None) as batch_op: batch_op.create_index( "idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False ) + batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False) def downgrade(): """Unapply add task_state and asset_state tables.""" with op.batch_alter_table("task_state", schema=None) as batch_op: + batch_op.drop_index("idx_task_state_expires_at") batch_op.drop_index("idx_task_state_lookup") op.drop_table("task_state") diff --git a/airflow-core/src/airflow/models/task_state.py b/airflow-core/src/airflow/models/task_state.py index dbc17e3b06950..72a7624eddd6e 100644 --- a/airflow-core/src/airflow/models/task_state.py +++ b/airflow-core/src/airflow/models/task_state.py @@ -19,7 +19,7 @@ from datetime import datetime -from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text +from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import Mapped, mapped_column @@ -39,19 +39,27 @@ class TaskStateModel(Base): __tablename__ = "task_state" - dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True) - task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True) - map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1") - key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + dag_run_id: Mapped[int] 
= mapped_column(Integer, nullable=False) + task_id: Mapped[str] = mapped_column(StringID(), nullable=False) + map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default="-1") + key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False) dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) run_id: Mapped[str] = mapped_column(StringID(), nullable=False) value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False) updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + # Optional override for early expiry. When set, garbage collection deletes this row when + # expires_at < now(), even if updated_at is recent. NULL means no early expiry — + # the row is still cleaned up by the global `updated_at + default_retention_days` check. + # Populated via task_state.set(retention_days=N) for keys that should expire differently + # than the deployment wide default. + expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) __table_args__ = ( - PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), + UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"), ForeignKeyConstraint( ["dag_run_id"], ["dag_run.id"], @@ -59,4 +67,5 @@ class TaskStateModel(Base): ondelete="CASCADE", ), Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"), + Index("idx_task_state_expires_at", "expires_at"), ) diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py index 31b4de3158fb4..f58c69f5808b3 100644 --- a/airflow-core/src/airflow/state/metastore.py +++ b/airflow-core/src/airflow/state/metastore.py @@ -19,17 +19,20 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from datetime import datetime, timedelta from typing import TYPE_CHECKING +import structlog from sqlalchemy import delete, 
select from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope from airflow._shared.timezones import timezone +from airflow.configuration import conf from airflow.models.asset_state import AssetStateModel from airflow.models.dagrun import DagRun from airflow.models.task_state import TaskStateModel from airflow.typing_compat import assert_never -from airflow.utils.session import NEW_SESSION, create_session_async, provide_session +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session from airflow.utils.sqlalchemy import get_dialect_name if TYPE_CHECKING: @@ -40,6 +43,21 @@ from sqlalchemy.orm import Session +log = structlog.get_logger(__name__) + + +def _compute_expires_at(now: datetime) -> datetime | None: + """ + Return the expiry timestamp for a new task state row based on config. + + Returns None if default_retention_days is 0 (never expires). + """ + retention_days = conf.getint("state_store", "default_retention_days") + if retention_days <= 0: + return None + return now + timedelta(days=retention_days) + + @asynccontextmanager async def _async_session(session: AsyncSession | None) -> AsyncGenerator[AsyncSession, None]: """Use provided async session or create a new one.""" @@ -200,6 +218,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = _compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -209,13 +228,14 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se key=key, value=value, updated_at=now, + expires_at=expires_at, ) stmt = _build_upsert_stmt( get_dialect_name(session), TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), + dict(value=value, updated_at=now, expires_at=expires_at), ) 
session.execute(stmt) @@ -276,6 +296,51 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None: ) ) + def cleanup(self) -> None: + """ + Remove expired task state rows. + + ``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single + ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0) + are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``. + """ + batch_size = conf.getint("state_store", "state_cleanup_batch_size") + now = timezone.utcnow() + + def _delete_batched(where_clause) -> int: + total = 0 + with create_session() as session: + while True: + id_query = select(TaskStateModel.id).where(where_clause) + if batch_size > 0: + id_query = id_query.limit(batch_size) + ids = session.scalars(id_query).all() + if not ids: + break + session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids))) + session.commit() + total += len(ids) + if batch_size <= 0 or len(ids) < batch_size: + break + return total + + deleted = _delete_batched(TaskStateModel.expires_at < now) + log.info("Deleted expired task_state rows", rows_deleted=deleted) + + def _summary_dry_run(self) -> dict[str, list]: + """Return rows that would be deleted by cleanup() without deleting anything.""" + now = timezone.utcnow() + cols = ( + TaskStateModel.dag_id, + TaskStateModel.run_id, + TaskStateModel.task_id, + TaskStateModel.map_index, + TaskStateModel.key, + ) + with create_session() as session: + expired = session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all() + return {"expired": list(expired)} + async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: row = await session.scalar( select(TaskStateModel).where( @@ -300,6 +365,7 @@ async def _aset_task_state( if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = 
_compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -309,6 +375,7 @@ async def _aset_task_state( key=key, value=value, updated_at=now, + expires_at=expires_at, ) # get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to stmt = _build_upsert_stmt( @@ -316,7 +383,7 @@ async def _aset_task_state( TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), + dict(value=value, updated_at=now, expires_at=expires_at), ) await session.execute(stmt) diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py new file mode 100644 index 0000000000000..e4b44eee13f10 --- /dev/null +++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from argparse import Namespace +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.cli.commands.state_store_command import cleanup_task_states +from airflow.state.metastore import MetastoreStateBackend + +pytestmark = pytest.mark.db_test + + +class TestStateStoreCleanupCommand: + def test_cleanup_calls_backend(self): + args = Namespace(dry_run=False, verbose=False) + backend = MetastoreStateBackend() + with ( + mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), + patch.object(backend, "cleanup"), + ): + cleanup_task_states(args) + + backend.cleanup.assert_called_once_with() + + def test_dry_run_does_not_call_backend(self, capsys): + args = Namespace(dry_run=True, verbose=False) + backend = MetastoreStateBackend() + with ( + mock.patch("airflow.cli.commands.state_store_command.get_state_backend", return_value=backend), + patch.object(backend, "_summary_dry_run", return_value={"expired": []}), + ): + cleanup_task_states(args) + + captured = capsys.readouterr() + assert "Nothing to delete" in captured.out + + def test_custom_backend_is_skipped(self, capsys): + args = Namespace(dry_run=False, verbose=False) + custom_backend = MagicMock(spec=[]) + with mock.patch( + "airflow.cli.commands.state_store_command.get_state_backend", return_value=custom_backend + ): + cleanup_task_states(args) + + captured = capsys.readouterr() + assert "Custom backend configured" in captured.out + assert not hasattr(custom_backend, "cleanup") or not custom_backend.cleanup.called diff --git a/airflow-core/tests/unit/state/test_metastore.py b/airflow-core/tests/unit/state/test_metastore.py index dfd154cc92ae5..d9e1ff33afd74 100644 --- a/airflow-core/tests/unit/state/test_metastore.py +++ b/airflow-core/tests/unit/state/test_metastore.py @@ -17,13 +17,18 @@ # under the License. 
from __future__ import annotations +from contextlib import contextmanager +from datetime import timedelta from typing import TYPE_CHECKING +from unittest.mock import patch import pytest -from sqlalchemy import select +from sqlalchemy import Delete, select from airflow._shared.timezones import timezone +from airflow.configuration import conf from airflow.models.asset import AssetModel +from airflow.models.asset_state import AssetStateModel from airflow.models.dagrun import DagRun, DagRunType from airflow.models.task_state import TaskStateModel from airflow.state import AssetScope, TaskScope, resolve_state_backend @@ -234,6 +239,113 @@ def test_clear_with_all_map_indices_flag_wipes_wide( assert backend.get(scope0, "job_id", session=session) is None assert backend.get(scope1, "job_id", session=session) is None + def test_set_populates_expires_at( + self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun + ): + """set() always populates expires_at so cleanup has a single pass.""" + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + backend.set(scope, "job_id", "app_1234", session=session) + session.flush() + + row = session.scalar(select(TaskStateModel).where(TaskStateModel.key == "job_id")) + assert row is not None + assert row.expires_at is not None + assert row.expires_at > row.updated_at + + def test_cleanup_removes_expired_rows( + self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun + ): + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + backend.set(scope, "old_key", "old_value", session=session) + backend.set(scope, "new_key", "new_value", session=session) + session.flush() + + # Backdate expires_at on old_key to simulate it having expired + old_row = session.scalar( + select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "old_key") + ) + assert old_row is not None + old_row.expires_at = timezone.utcnow() - timedelta(hours=1) + session.flush() + session.commit() + + 
backend.cleanup()
+
+        session.expire_all()
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "old_key")) is None
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "new_key")) is not None
+
+    def test_cleanup_removes_expires_at_rows(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "short_lived", "value", session=session)
+        session.flush()
+
+        row = session.scalar(
+            select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "short_lived")
+        )
+        assert row is not None
+        row.expires_at = timezone.utcnow() - timedelta(hours=1)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+
+        # cleaned up via expires_at, even though updated_at is recent
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "short_lived")) is None
+
+    @conf_vars({("state_store", "state_cleanup_batch_size"): "2"})
+    def test_cleanup_batches_deletes(self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun):
+        """cleanup() issues one DELETE per batch, not one DELETE for all rows at once.
+
+        Verifying this is not straightforward because cleanup() creates its own internal session,
+        so we cannot simply inspect it from outside; what we do instead is:
+
+        1. Patch `create_session` in the metastore module with a thin wrapper (`tracking_cs`) that
+           yields the real session but replaces `session.execute` with a spy.
+        2. The spy checks whether the statement being executed is a sqla Delete object and
+           records it if so.
+        3. After cleanup() returns, we assert that exactly ceil(expired_rows / batch_size) DELETE statements were recorded.
+ """ + import airflow.state.metastore as metastore_mod + + scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID) + for key in ("k1", "k2", "k3", "k4", "k5"): + backend.set(scope, key, "v", session=session) + session.flush() + + session.execute( + TaskStateModel.__table__.update().values(expires_at=timezone.utcnow() - timedelta(hours=1)) + ) + session.commit() + + deletes = [] + original_cs = metastore_mod.create_session + + @contextmanager + def tracking_cs(*args, **kwargs): + with original_cs(*args, **kwargs) as s: + orig_execute = s.execute + + def tracked(stmt, *a, **kw): + if isinstance(stmt, Delete): + deletes.append(stmt) + return orig_execute(stmt, *a, **kw) + + s.execute = tracked + yield s + + with patch.object(metastore_mod, "create_session", side_effect=tracking_cs): + backend.cleanup() + + session.expire_all() + + # batch_size=2, 5 rows -> delete runs 3 times (2+2+1) + assert len(deletes) == 3 + class TestMetastoreStateBackendAssetScope: def test_get_returns_none_for_missing_key( @@ -306,6 +418,19 @@ def test_different_assets_are_isolated( assert backend.get(scope2, "watermark", session=session) is None + def test_cleanup_does_not_touch_asset_state( + self, session: Session, backend: MetastoreStateBackend, asset: AssetModel + ): + scope = AssetScope(asset_id=asset.id) + backend.set(scope, "watermark", "2026-01-01", session=session) + session.flush() + session.commit() + + backend.cleanup() + + session.expire_all() + assert session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id == asset.id)) is not None + @pytest.mark.asyncio(loop_scope="class") class TestMetastoreStateBackendAsync: @@ -390,6 +515,19 @@ async def test_aset_and_aget_with_provided_session( assert result == "app_with_session" +class TestStateStoreConfig: + def test_defaults(self): + assert conf.getint("state_store", "default_retention_days") == 30 + assert conf.getint("state_store", "state_cleanup_batch_size") == 0 + + @conf_vars( + {("state_store", 
"default_retention_days"): "7", ("state_store", "state_cleanup_batch_size"): "50"} + ) + def test_overrides(self): + assert conf.getint("state_store", "default_retention_days") == 7 + assert conf.getint("state_store", "state_cleanup_batch_size") == 50 + + class TestResolveStateBackend: @conf_vars({("state_store", "backend"): "airflow.state.metastore.MetastoreStateBackend"}) def test_resolve_returns_configured_backend(self): diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py index 4920f66ae6764..e231bdfd3bd8c 100644 --- a/shared/state/src/airflow_shared/state/__init__.py +++ b/shared/state/src/airflow_shared/state/__init__.py @@ -157,3 +157,12 @@ async def aclear( ``session`` is optional. If provided, implementations should use it directly. If ``None``, implementations manage their own async session internally. """ + + def cleanup(self) -> None: + """ + Remove expired and orphaned state records. + + This is a no-op by default. Custom backends override this to implement their own + retention policy. The backend is responsible for reading any relevant config (e.g. + ``[state_store] default_retention_days``) and deciding what to delete. + """