Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from __future__ import annotations

from pydantic import Field

from airflow.api_fastapi.core_api.base import StrictBaseModel


Expand All @@ -30,3 +32,4 @@ class TaskStatePutBody(StrictBaseModel):
"""Request body for setting a task state value."""

value: str
retention_days: int | None = Field(default=None, ge=0)
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def set_task_state(
) -> None:
"""Set a task state key, creating or updating the row."""
scope = _get_task_scope_for_ti(task_instance_id, session)
get_state_backend().set(scope, key, body.value, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it
get_state_backend().set(scope, key, body.value, retention_days=body.retention_days, session=session) # type: ignore[call-arg] # @provide_session adds session kwarg at runtime; BaseStateBackend signature omits it so mypy can't see it


@router.delete("/{task_instance_id}/{key}", status_code=status.HTTP_204_NO_CONTENT)
Expand Down
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,19 @@ class GroupCommand(NamedTuple):
args=(ARG_OUTPUT, ARG_VERBOSE),
),
)
STATE_STORE_COMMANDS = (
ActionCommand(
name="cleanup",
help="Remove expired stored state via the configured state backend",
description=(
"Reads [state_store] default_retention_days from config and deletes task_state rows "
"older than the configured threshold. Use --dry-run to preview without deleting."
),
func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup"),
args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
),
)

DB_COMMANDS = (
ActionCommand(
name="check-migrations",
Expand Down Expand Up @@ -2108,6 +2121,11 @@ class GroupCommand(NamedTuple):
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
GroupCommand(
name="state-store",
help="Manage task and asset state storage",
subcommands=STATE_STORE_COMMANDS,
),
ActionCommand(
name="rotate-fernet-key",
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
Expand Down
50 changes: 50 additions & 0 deletions airflow-core/src/airflow/cli/commands/state_store_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging

from airflow.state import get_state_backend
from airflow.state.metastore import MetastoreStateBackend

log = logging.getLogger(__name__)

# Other state operations (list, get, delete per key) will be added here in the future.


def cleanup(args) -> None:
    """Remove expired task state rows via the configured state backend.

    With ``--dry-run``, previews the rows that would be deleted (only possible
    when the metastore backend is configured) and exits without deleting.
    """
    backend = get_state_backend()

    # Real run: delegate straight to the backend.
    if not args.dry_run:
        log.info("Running state store cleanup")
        backend.cleanup()
        return

    # Dry run: only the metastore backend exposes a row-level preview.
    if not isinstance(backend, MetastoreStateBackend):
        print("Custom backend configured — cannot preview rows.")
        return

    expired = backend._summary_dry_run_()["expired"]
    if not expired:
        print("Nothing to delete.")
        return

    print(f"Would delete {len(expired)} task state row(s):\n")
    for dag_id, run_id, task_id, map_index, key in expired:
        print(
            f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}"
        )
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,24 @@ state_store:
type: string
example: "mypackage.state.CustomStateBackend"
default: "airflow.state.metastore.MetastoreStateBackend"
default_retention_days:
description: |
Number of days to retain task state rows after their last update.
Rows older than this are removed by the scheduler's periodic cleanup.
This config does not affect asset_state rows.
Set to 0 to disable time-based cleanup entirely.
version_added: 3.3.0
type: integer
example: "7"
default: "30"
state_cleanup_batch_size:
description: |
Number of rows deleted per batch during cleanup. Defaults to 0 (no batching).
Tune this on deployments with large task_state tables to improve performance per transaction.
version_added: 3.3.0
type: integer
example: "10000"
default: "0"

profiling:
description: |
Expand Down
32 changes: 31 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy import (
CTE,
and_,
case,
delete,
exists,
func,
inspect,
or_,
select,
text,
tuple_,
update,
)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
Expand Down Expand Up @@ -70,6 +83,7 @@
TaskInletAssetReference,
TaskOutletAssetReference,
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -3092,6 +3106,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:

self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)
self._cleanup_orphaned_asset_state(session=session)

@staticmethod
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:
Expand Down Expand Up @@ -3200,6 +3215,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)

@staticmethod
def _cleanup_orphaned_asset_state(*, session: Session) -> None:
    """
    Delete asset_state rows whose asset is no longer in asset_active.

    Once _orphan_unreferenced_assets drops an asset from asset_active, nothing
    can write to its asset_state rows again, so they are dead data. Running
    this in the same asset-orphanage pass keeps the table from accumulating
    unreachable rows.
    """
    # Subquery of asset ids that still have an active (name, uri) entry.
    still_active = select(AssetModel.id).join(
        AssetActive,
        and_(AssetActive.name == AssetModel.name, AssetActive.uri == AssetModel.uri),
    )
    stmt = delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(still_active))
    session.execute(stmt)

def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def upgrade():
)
op.create_table(
"task_state",
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("dag_run_id", sa.Integer(), nullable=False),
sa.Column("task_id", StringID(), nullable=False),
sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False),
Expand All @@ -65,20 +66,24 @@ def upgrade():
sa.Column("run_id", StringID(), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
sa.Column("expires_at", UtcDateTime(), nullable=True),
sa.ForeignKeyConstraint(
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
sa.PrimaryKeyConstraint("id", name="task_state_pkey"),
sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
)
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.create_index(
"idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False
)
batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False)


def downgrade():
"""Unapply add task_state and asset_state tables."""
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.drop_index("idx_task_state_expires_at")
batch_op.drop_index("idx_task_state_lookup")

op.drop_table("task_state")
Expand Down
21 changes: 15 additions & 6 deletions airflow-core/src/airflow/models/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from datetime import datetime

from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text
from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -39,24 +39,33 @@ class TaskStateModel(Base):

__tablename__ = "task_state"

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True)
map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)

dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False)

dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)

value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
# Optional override for early expiry. When set, garbage collection deletes this row when
# expires_at < now(), even if updated_at is recent. NULL means no early expiry —
# the row is still cleaned up by the global `updated_at + default_retention_days` check.
# Populated via task_state.set(retention_days=N) for keys that should expire differently
# than the deployment wide default.
expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)

__table_args__ = (
PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
ForeignKeyConstraint(
["dag_run_id"],
["dag_run.id"],
name="task_state_dag_run_fkey",
ondelete="CASCADE",
),
Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"),
Index("idx_task_state_expires_at", "expires_at"),
)
Loading
Loading