diff --git a/Makefile b/Makefile index 3892731352..9c41f62500 100644 --- a/Makefile +++ b/Makefile @@ -49,11 +49,13 @@ install-dev-dbt-%: $(MAKE) install-dev; \ if [ "$$version" = "1.6.0" ]; then \ echo "Applying overrides for dbt 1.6.0"; \ - $(PIP) install 'pydantic>=2.0.0' 'google-cloud-bigquery==3.30.0' 'databricks-sdk==0.28.0' --reinstall; \ + $(PIP) install 'pydantic>=2.0.0' 'google-cloud-bigquery==3.30.0' 'databricks-sdk==0.28.0' \ + 'pyOpenSSL>=24.0.0' --reinstall; \ fi; \ if [ "$$version" = "1.7.0" ]; then \ echo "Applying overrides for dbt 1.7.0"; \ - $(PIP) install 'databricks-sdk==0.28.0' --reinstall; \ + $(PIP) install 'databricks-sdk==0.28.0' \ + 'pyOpenSSL>=24.0.0' --reinstall; \ fi; \ if [ "$$version" = "1.5.0" ]; then \ echo "Applying overrides for dbt 1.5.0"; \ diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index c19d2ca629..7e4fc841cc 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -620,6 +620,11 @@ def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any is_flag=True, help="Wait for the environment to be deleted before returning. If not specified, the environment will be deleted asynchronously by the janitor process. This option requires a connection to the data warehouse.", ) +@click.option( + "--cleanup-snapshots", + is_flag=True, + help="After invalidating, immediately delete physical snapshot tables that are exclusively owned by this environment (not referenced by any other environment). Cleanup runs synchronously regardless of --sync.", +) @click.pass_context @error_handler @cli_analytics diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 14e37e1313..6659efd38c 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -99,6 +99,7 @@ Snapshot, SnapshotEvaluator, SnapshotFingerprint, + SnapshotId, missing_intervals, to_table_mapping, ) @@ -108,7 +109,11 @@ StateReader, StateSync, ) -from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots +from sqlmesh.core.janitor import ( + cleanup_expired_views, + delete_expired_snapshots, + delete_snapshots_for_environment, +) from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, @@ -1835,18 +1840,50 @@ def apply( ) @python_api_analytics - def invalidate_environment(self, name: str, sync: bool = False) -> None: + def invalidate_environment( + self, name: str, sync: bool = False, cleanup_snapshots: bool = False + ) -> None: """Invalidates the target environment by setting its expiration timestamp to now. Args: name: The name of the environment to invalidate. sync: If True, the call blocks until the environment is deleted. Otherwise, the environment will be deleted asynchronously by the janitor process. + cleanup_snapshots: If True, immediately deletes physical snapshot tables that are exclusively + owned by this environment (not referenced by any other environment). Cleanup runs + synchronously regardless of --sync. """ name = Environment.sanitize_name(name) + sync = sync or cleanup_snapshots + + target_snapshot_ids: t.Set[SnapshotId] = set() + if cleanup_snapshots: + # Capture snapshot IDs before invalidation so we can scope the cleanup afterwards. + env = self.state_sync.get_environment(name) + if env is None: + logger.warning("Environment '%s' does not exist; skipping snapshot cleanup.", name) + return + target_snapshot_ids = {s.snapshot_id for s in env.snapshots} + self.state_sync.invalidate_environment(name) + if sync: self._cleanup_environments(name=name) + if cleanup_snapshots and target_snapshot_ids: + failures = delete_snapshots_for_environment( + self.state_sync, + self.snapshot_evaluator, + target_snapshot_ids, + console=self.console, + ) + if failures: + summary = "\n".join(failures) + if self.config.janitor.warn_on_delete_failure: + self.console.log_warning( + f"Snapshot cleanup completed with failures:\n{summary}" + ) + else: + raise SQLMeshError(f"Snapshot cleanup completed with failures:\n{summary}") self.console.log_success(f"Environment '{name}' deleted.") else: self.console.log_success(f"Environment '{name}' invalidated.") diff --git a/sqlmesh/core/janitor.py b/sqlmesh/core/janitor.py index 92d889e276..fc95566361 100644 --- a/sqlmesh/core/janitor.py +++ b/sqlmesh/core/janitor.py @@ -8,7 +8,7 @@ from sqlmesh.core.console import Console from sqlmesh.core.dialect import schema_ from sqlmesh.core.environment import Environment -from sqlmesh.core.snapshot import SnapshotEvaluator +from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotId from sqlmesh.core.state_sync import StateSync from sqlmesh.core.state_sync.common import ( logger, @@ -193,3 +193,72 @@ def delete_expired_snapshots( failures.append(message) logger.info("Cleaned up %s expired snapshots", num_expired_snapshots) return failures + + +def delete_snapshots_for_environment( + state_sync: StateSync, + snapshot_evaluator: SnapshotEvaluator, + target_snapshot_ids: t.Collection[SnapshotId], + *, + force_delete: bool = False, + console: t.Optional[Console] = None, +) -> t.List[str]: + """Delete snapshots that are exclusively owned by a specific (now-deleted) environment. + + This performs a scoped cleanup: only the provided snapshot IDs are considered for deletion, + and only those that are not referenced by any remaining active environment will be removed. + + Args: + state_sync: StateSync instance to query and delete snapshot state from. + snapshot_evaluator: SnapshotEvaluator instance to clean up physical tables. + target_snapshot_ids: The snapshot IDs to consider for deletion (typically from the + environment that was just invalidated/deleted). + force_delete: If True, delete snapshot state records even when physical table cleanup fails. + console: Optional console for reporting progress. + + Returns: + List of failure messages encountered during cleanup. + """ + if not target_snapshot_ids: + return [] + + failures: t.List[str] = [] + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=target_snapshot_ids, + ) + if batch is None: + return failures + + logger.info( + "Cleaning up %s snapshots exclusively owned by invalidated environment", + len(batch.expired_snapshot_ids), + ) + + cleanup_succeeded = True + if batch.cleanup_tasks: + try: + snapshot_evaluator.cleanup( + target_snapshots=batch.cleanup_tasks, + on_complete=console.update_cleanup_progress if console else None, + ) + except Exception as failed_drops: + message = f"Failed to clean up: {failed_drops}" + logger.warning(message) + failures.append(message) + cleanup_succeeded = False + + if cleanup_succeeded or force_delete: + try: + state_sync.delete_snapshots(batch.expired_snapshot_ids) + logger.info( + "Cleaned up %s snapshots from invalidated environment", + len(batch.expired_snapshot_ids), + ) + except Exception as e: + message = f"Failed to delete snapshot state records: {e}" + logger.warning(message) + failures.append(message) + + return failures diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 5c35be5ccb..6f5023304f 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -308,6 +308,7 @@ def get_expired_snapshots( batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: """Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier). @@ -315,6 +316,8 @@ def get_expired_snapshots( current_ts: Timestamp used to evaluate expiration. ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). batch_range: The range of the batch to fetch. + target_snapshot_ids: If provided, only consider snapshots with these IDs. Useful for + scoped cleanup after environment invalidation. Returns: A batch describing expired snapshots or None if no snapshots are pending cleanup. @@ -368,6 +371,7 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: """Removes expired snapshots. @@ -379,6 +383,8 @@ def delete_expired_snapshots( ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment current_ts: Timestamp used to evaluate expiration. + target_snapshot_ids: If provided, only delete snapshots with these IDs. Useful for + scoped cleanup after environment invalidation. """ @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index 77f3fc6ba5..edf74e03f9 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -113,12 +113,14 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: self.snapshot_cache.clear() self.state_sync.delete_expired_snapshots( batch_range=batch_range, ignore_ttl=ignore_ttl, current_ts=current_ts, + target_snapshot_ids=target_snapshot_ids, ) def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 572e54b7f1..8fb732e17c 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -267,6 +267,7 @@ def get_expired_snapshots( batch_range: ExpiredBatchRange, current_ts: t.Optional[int] = None, ignore_ttl: bool = False, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: current_ts = current_ts or now_timestamp() return self.snapshot_state.get_expired_snapshots( @@ -274,6 +275,7 @@ def get_expired_snapshots( current_ts=current_ts, ignore_ttl=ignore_ttl, batch_range=batch_range, + target_snapshot_ids=target_snapshot_ids, ) def get_expired_environments( @@ -287,11 +289,13 @@ def delete_expired_snapshots( batch_range: ExpiredBatchRange, ignore_ttl: bool = False, current_ts: t.Optional[int] = None, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> None: batch = self.get_expired_snapshots( ignore_ttl=ignore_ttl, current_ts=current_ts, batch_range=batch_range, + target_snapshot_ids=target_snapshot_ids, ) if batch and batch.expired_snapshot_ids: self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 9b4337b504..287a69013b 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -170,6 +170,7 @@ def get_expired_snapshots( current_ts: int, ignore_ttl: bool, batch_range: ExpiredBatchRange, + target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None, ) -> t.Optional[ExpiredSnapshotBatch]: expired_query = exp.select("name", "identifier", "version", "updated_ts").from_( self.snapshots_table @@ -180,6 +181,16 @@ def get_expired_snapshots( (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts ) + if target_snapshot_ids is not None: + target_conditions = list( + snapshot_id_filter( + self.engine_adapter, + target_snapshot_ids, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + ) + expired_query = expired_query.where(exp.or_(*target_conditions)) + expired_query = expired_query.where(batch_range.where_filter) promoted_snapshot_ids = { diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index 5b1b077216..59605893ba 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -246,9 +246,7 @@ def wrap(*args: t.Any, **kwargs: t.Any) -> t.Any: class classproperty(property): - """ - Similar to a normal property but works for class methods - """ + """Similar to a normal property but works for class methods""" def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: return classmethod(self.fget).__get__(None, owner)() # type: ignore diff --git a/tests/core/integration/test_aux_commands.py b/tests/core/integration/test_aux_commands.py index 7de585576d..04cd01607f 100644 --- a/tests/core/integration/test_aux_commands.py +++ b/tests/core/integration/test_aux_commands.py @@ -481,6 +481,76 @@ def test_invalidating_environment(sushi_context: Context): assert start_schemas - schemas_after_janitor == {"sushi__dev"} +def test_invalidate_environment_cleanup_snapshots_scoped(tmp_path: Path): + """Test that --cleanup-snapshots only deletes snapshots exclusively owned by the invalidated env.""" + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + (models_dir / "model2.sql").write_text("MODEL(name test.model2, kind FULL); SELECT 2 AS col") + + ctx = Context( + paths=[tmp_path], + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + ) + + # Apply both models to prod and dev. + ctx.plan("prod", no_prompts=True, auto_apply=True) + ctx.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True) + + prod_env = ctx.state_sync.get_environment("prod") + dev_env = ctx.state_sync.get_environment("dev") + assert prod_env is not None + assert dev_env is not None + + prod_snapshot_ids = {s.snapshot_id for s in prod_env.snapshots} + dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots} + + # In a virtual environment, dev shares snapshots with prod. + # Shared snapshots must NOT be deleted when invalidating dev with --cleanup-snapshots. + shared_snapshot_ids = prod_snapshot_ids & dev_snapshot_ids + + ctx.invalidate_environment("dev", cleanup_snapshots=True) + + # The dev environment record should be gone. + assert ctx.state_sync.get_environment("dev") is None + + # Shared snapshots (also in prod) must still exist. + remaining_snapshots = ctx.state_sync.get_snapshots(list(shared_snapshot_ids)) + assert set(remaining_snapshots.keys()) == shared_snapshot_ids + + # Prod environment should be unaffected. + assert ctx.state_sync.get_environment("prod") is not None + + +def test_invalidate_environment_cleanup_snapshots_exclusive(tmp_path: Path): + """Test that --cleanup-snapshots deletes snapshots exclusively owned by the invalidated env.""" + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + + ctx = Context( + paths=[tmp_path], + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + ) + + # Apply model1 to dev only (not prod). These snapshots will be exclusively owned by dev. + ctx.plan("dev", no_prompts=True, auto_apply=True) + + dev_env = ctx.state_sync.get_environment("dev") + assert dev_env is not None + dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots} + assert dev_snapshot_ids + + ctx.invalidate_environment("dev", cleanup_snapshots=True) + + # The dev environment record should be gone. + assert ctx.state_sync.get_environment("dev") is None + + # All dev-exclusive snapshots should have been deleted. + remaining_snapshots = ctx.state_sync.get_snapshots(list(dev_snapshot_ids)) + assert not remaining_snapshots + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable): context, plan = init_and_plan_context("examples/sushi") diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 348a883fd5..582fa296d1 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -4220,3 +4220,144 @@ def test_state_version_is_too_old( match="The current state belongs to an old version of SQLMesh that is no longer supported. Please upgrade to 0.134.0 first before upgrading to.*", ): state_sync.migrate(skip_backup=True) + + +def test_get_expired_snapshots_scoped_to_target_ids( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that get_expired_snapshots with target_snapshot_ids only returns snapshots in the target set.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Both snapshots are expired (no active environments). + # When scoped to only snapshot_a, only snapshot_a should be returned. + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=[snapshot_a.snapshot_id], + ) + assert batch is not None + assert batch.expired_snapshot_ids == {snapshot_a.snapshot_id} + assert [t.snapshot.name for t in batch.cleanup_tasks] == [snapshot_a.name] + + # snapshot_b should still exist because it was not in the target set. + batch_all = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + ) + assert batch_all is not None + assert snapshot_b.snapshot_id in batch_all.expired_snapshot_ids + + +def test_get_expired_snapshots_scoped_excludes_shared_snapshots( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that scoped cleanup respects protection: snapshots shared with other environments are not deleted.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Promote snapshot_b to another active environment (prod-like). + prod_env = Environment( + name="prod", + snapshots=[snapshot_b.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(prod_env) + state_sync.finalize(prod_env) + + # Even though snapshot_b is in the target set, it should NOT be returned + # because it is still referenced by prod_env. + batch = state_sync.get_expired_snapshots( + ignore_ttl=True, + batch_range=ExpiredBatchRange.all_batch_range(), + target_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + assert batch is not None + # Only snapshot_a is exclusively owned (not referenced by any active environment). + assert batch.expired_snapshot_ids == {snapshot_a.snapshot_id} + assert [t.snapshot.name for t in batch.cleanup_tasks] == [snapshot_a.name] + + +def test_delete_expired_snapshots_scoped( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + """Test that delete_expired_snapshots with target_snapshot_ids only deletes scoped snapshots.""" + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + # Delete only snapshot_a via scoped cleanup. + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange.all_batch_range(), + ignore_ttl=True, + target_snapshot_ids=[snapshot_a.snapshot_id], + ) + + # snapshot_a should be deleted, snapshot_b should remain. + assert not state_sync.get_snapshots([snapshot_a]) + assert state_sync.get_snapshots([snapshot_b])