Skip to content
Merged
19 changes: 19 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,20 @@ class GroupCommand(NamedTuple):
args=(ARG_VERBOSE,),
),
)
# Commands registered under the ``airflow state-store`` CLI group.
# Currently a single maintenance command; the retention threshold it applies
# comes from [state_store] default_retention_days in the Airflow config.
STATE_STORE_COMMANDS = (
    ActionCommand(
        name="cleanup-task-states",
        help="Remove expired task state rows (MetastoreStateBackend only)",
        description=(
            "Reads [state_store] default_retention_days from config and deletes task_state rows "
            "older than the configured threshold. Only applies when MetastoreStateBackend is configured; "
            "custom backends are skipped. Use --dry-run to preview without deleting."
        ),
        func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"),
        args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
    ),
)

DB_COMMANDS = (
ActionCommand(
name="check-migrations",
Expand Down Expand Up @@ -2115,6 +2129,11 @@ class GroupCommand(NamedTuple):
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
GroupCommand(
name="state-store",
help="Manage task and asset state storage",
subcommands=STATE_STORE_COMMANDS,
),
ActionCommand(
name="rotate-fernet-key",
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
Expand Down
49 changes: 49 additions & 0 deletions airflow-core/src/airflow/cli/commands/state_store_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging

from airflow.state import get_state_backend
from airflow.state.metastore import MetastoreStateBackend

log = logging.getLogger(__name__)

# Other state operations (list, get, delete per key) will be added here in the future.


def cleanup_task_states(args) -> None:
    """Remove expired task state rows (MetastoreStateBackend only)."""
    backend = get_state_backend()

    # Cleanup semantics are only defined for the built-in metastore backend;
    # anything else is left untouched.
    if not isinstance(backend, MetastoreStateBackend):
        print("Custom backend configured — skipping cleanup (not supported).")
        return

    if not args.dry_run:
        log.info("Running task state cleanup")
        backend.cleanup()
        return

    # Dry run: report what cleanup() would remove, without deleting anything.
    rows = backend._summary_dry_run()["expired"]
    if not rows:
        print("Nothing to delete.")
        return
    print(f"Would delete {len(rows)} task state row(s):\n")
    for dag_id, run_id, task_id, map_index, key in rows:
        print(f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}")
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,24 @@ state_store:
type: string
example: "mypackage.state.CustomStateBackend"
default: "airflow.state.metastore.MetastoreStateBackend"
default_retention_days:
description: |
Number of days to retain task state after their last update.
Rows older than this are removed when cleanup is triggered.
This config does not affect asset_state rows.
Set to 0 to disable time-based cleanup entirely.
version_added: 3.3.0
type: integer
example: "7"
default: "30"
state_cleanup_batch_size:
description: |
Number of rows deleted per batch during cleanup. Defaults to 0 (no batching).
Tune this on deployments with large task_state tables to improve performance per transaction.
version_added: 3.3.0
type: integer
example: "10000"
default: "0"

profiling:
description: |
Expand Down
32 changes: 31 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy import (
CTE,
and_,
case,
delete,
exists,
func,
inspect,
or_,
select,
text,
tuple_,
update,
)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
Expand Down Expand Up @@ -70,6 +83,7 @@
TaskInletAssetReference,
TaskOutletAssetReference,
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -3096,6 +3110,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:

self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)
self._cleanup_orphaned_asset_state(session=session)

@staticmethod
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:
Expand Down Expand Up @@ -3204,6 +3219,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)

@staticmethod
def _cleanup_orphaned_asset_state(*, session: Session) -> None:
    """
    Delete asset_state rows for assets no longer active in any Dag.

    When _orphan_unreferenced_assets removes an asset from asset_active, its
    asset_state rows become unreachable — no task can write to them anymore.
    This runs in the same pass as asset orphanage to keep the table clean.
    """
    # Assets still present in asset_active, matched by (name, uri).
    currently_active = select(AssetModel.id).join(
        AssetActive,
        (AssetActive.name == AssetModel.name) & (AssetActive.uri == AssetModel.uri),
    )
    stmt = delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(currently_active))
    session.execute(stmt)

def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def upgrade():
)
op.create_table(
"task_state",
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("dag_run_id", sa.Integer(), nullable=False),
sa.Column("task_id", StringID(), nullable=False),
sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False),
Expand All @@ -65,20 +66,24 @@ def upgrade():
sa.Column("run_id", StringID(), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
sa.Column("expires_at", UtcDateTime(), nullable=True),
sa.ForeignKeyConstraint(
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
sa.PrimaryKeyConstraint("id", name="task_state_pkey"),
sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
)
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.create_index(
"idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False
)
batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False)


def downgrade():
"""Unapply add task_state and asset_state tables."""
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.drop_index("idx_task_state_expires_at")
batch_op.drop_index("idx_task_state_lookup")

op.drop_table("task_state")
Expand Down
21 changes: 15 additions & 6 deletions airflow-core/src/airflow/models/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from datetime import datetime

from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text
from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -39,24 +39,33 @@ class TaskStateModel(Base):

__tablename__ = "task_state"

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True)
map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False)

dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)

value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
# Optional override for early expiry. When set, garbage collection deletes this row when
# expires_at < now(), even if updated_at is recent. NULL means no early expiry —
# the row is still cleaned up by the global `updated_at + default_retention_days` check.
# Populated via task_state.set(retention_days=N) for keys that should expire differently
# than the deployment wide default.
expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)

__table_args__ = (
PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
ForeignKeyConstraint(
["dag_run_id"],
["dag_run.id"],
name="task_state_dag_run_fkey",
ondelete="CASCADE",
),
Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"),
Index("idx_task_state_expires_at", "expires_at"),
)
73 changes: 70 additions & 3 deletions airflow-core/src/airflow/state/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import delete, select

from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope
from airflow._shared.timezones import timezone
from airflow.configuration import conf
from airflow.models.asset_state import AssetStateModel
from airflow.models.dagrun import DagRun
from airflow.models.task_state import TaskStateModel
from airflow.typing_compat import assert_never
from airflow.utils.session import NEW_SESSION, create_session_async, provide_session
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
from airflow.utils.sqlalchemy import get_dialect_name

if TYPE_CHECKING:
Expand All @@ -40,6 +43,21 @@
from sqlalchemy.orm import Session


log = structlog.get_logger(__name__)


def _compute_expires_at(now: datetime) -> datetime | None:
"""
Return the expiry timestamp for a new task state row based on config.

Returns None if default_retention_days is 0 (never expires).
"""
retention_days = conf.getint("state_store", "default_retention_days")
if retention_days <= 0:
return None
return now + timedelta(days=retention_days)


@asynccontextmanager
async def _async_session(session: AsyncSession | None) -> AsyncGenerator[AsyncSession, None]:
"""Use provided async session or create a new one."""
Expand Down Expand Up @@ -200,6 +218,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}")
now = timezone.utcnow()
expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
Expand All @@ -209,13 +228,14 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se
key=key,
value=value,
updated_at=now,
expires_at=expires_at,
)
stmt = _build_upsert_stmt(
get_dialect_name(session),
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
dict(value=value, updated_at=now),
dict(value=value, updated_at=now, expires_at=expires_at),
)
session.execute(stmt)

Expand Down Expand Up @@ -276,6 +296,51 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None:
)
)

def cleanup(self) -> None:
    """
    Remove expired task state rows.

    ``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single
    ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0)
    are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``;
    when batching is disabled the delete runs as one statement instead of first
    materializing every matching primary key in memory.
    """
    batch_size = conf.getint("state_store", "state_cleanup_batch_size")
    now = timezone.utcnow()
    predicate = TaskStateModel.expires_at < now

    if batch_size <= 0:
        # No batching: a single DELETE avoids loading all matching ids into memory.
        with create_session() as session:
            result = session.execute(delete(TaskStateModel).where(predicate))
            deleted = result.rowcount
    else:
        # Batched: select up to batch_size ids, delete them, commit, repeat.
        # Committing per batch keeps individual transactions small on large tables.
        deleted = 0
        with create_session() as session:
            while True:
                ids = session.scalars(
                    select(TaskStateModel.id).where(predicate).limit(batch_size)
                ).all()
                if not ids:
                    break
                session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids)))
                session.commit()
                deleted += len(ids)
                if len(ids) < batch_size:
                    # Short batch means the table is drained; skip the extra round trip.
                    break

    log.info("Deleted expired task_state rows", rows_deleted=deleted)

def _summary_dry_run(self) -> dict[str, list]:
    """Return rows that would be deleted by cleanup() without deleting anything."""
    cutoff = timezone.utcnow()
    # Same predicate as cleanup(): only rows whose expires_at has passed.
    stmt = select(
        TaskStateModel.dag_id,
        TaskStateModel.run_id,
        TaskStateModel.task_id,
        TaskStateModel.map_index,
        TaskStateModel.key,
    ).where(TaskStateModel.expires_at < cutoff)
    with create_session() as session:
        rows = session.execute(stmt).all()
    return {"expired": list(rows)}

async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None:
row = await session.scalar(
select(TaskStateModel).where(
Expand All @@ -300,6 +365,7 @@ async def _aset_task_state(
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}")
now = timezone.utcnow()
expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
Expand All @@ -309,14 +375,15 @@ async def _aset_task_state(
key=key,
value=value,
updated_at=now,
expires_at=expires_at,
)
# get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to
stmt = _build_upsert_stmt(
get_dialect_name(session.sync_session),
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
dict(value=value, updated_at=now),
dict(value=value, updated_at=now, expires_at=expires_at),
)
await session.execute(stmt)

Expand Down
Loading
Loading