Skip to content

Commit 29ba3a0

Browse files
authored
fix: Consolidate auto change categorization in the plan definition (#604)
1 parent 9839223 commit 29ba3a0

3 files changed

Lines changed: 91 additions & 58 deletions

File tree

sqlmesh/core/context.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
Snapshot,
6363
SnapshotEvaluator,
6464
SnapshotFingerprint,
65-
categorize_change,
6665
to_table_mapping,
6766
)
6867
from sqlmesh.core.state_sync import StateReader, StateSync
@@ -637,18 +636,10 @@ def plan(
637636
is_dev=environment != c.PROD,
638637
forward_only=forward_only,
639638
environment_ttl=self.config.environment_ttl,
639+
categorizer_config=self.config.auto_categorize_changes,
640+
auto_categorization_enabled=not no_auto_categorization,
640641
)
641642

642-
if not no_auto_categorization and not forward_only:
643-
# Attempt to automatically determine and assign change categories.
644-
for new, old in plan.context_diff.modified_snapshots.values():
645-
if new in plan.directly_modified and plan.is_new_snapshot(new):
646-
change_category = categorize_change(
647-
new, old, config=self.config.auto_categorize_changes
648-
)
649-
if change_category is not None:
650-
plan.set_choice(new, change_category)
651-
652643
if not no_prompts:
653644
self.console.plan(plan, auto_apply)
654645
elif auto_apply:

sqlmesh/core/plan/definition.py

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

33
import typing as t
4-
from collections import defaultdict, deque
4+
from collections import defaultdict
55
from enum import Enum
66

77
from sqlmesh.core import scheduler
8+
from sqlmesh.core.config import CategorizerConfig
89
from sqlmesh.core.context_diff import ContextDiff
910
from sqlmesh.core.environment import Environment
1011
from sqlmesh.core.snapshot import (
1112
Intervals,
1213
Snapshot,
1314
SnapshotChangeCategory,
1415
SnapshotId,
16+
categorize_change,
1517
merge_intervals,
1618
)
1719
from sqlmesh.core.state_sync import StateReader
@@ -52,6 +54,8 @@ class Plan:
5254
is_dev: Whether this plan is for development purposes.
5355
forward_only: Whether the purpose of the plan is to make forward only changes.
5456
environment_ttl: The period of time that a development environment should exist before being deleted.
57+
categorizer_config: Auto categorization settings.
58+
auto_categorization_enabled: Whether to apply auto categorization.
5559
"""
5660

5761
def __init__(
@@ -68,6 +72,8 @@ def __init__(
6872
is_dev: bool = False,
6973
forward_only: bool = False,
7074
environment_ttl: t.Optional[str] = None,
75+
categorizer_config: t.Optional[CategorizerConfig] = None,
76+
auto_categorization_enabled: bool = True,
7177
):
7278
self.context_diff = context_diff
7379
self.override_start = start is not None
@@ -78,6 +84,8 @@ def __init__(
7884
self.is_dev = is_dev
7985
self.forward_only = forward_only
8086
self.environment_ttl = environment_ttl
87+
self.categorizer_config = categorizer_config or CategorizerConfig()
88+
self.auto_categorization_enabled = auto_categorization_enabled
8189
self._start = start if start or not (is_dev and forward_only) else yesterday_ds()
8290
self._end = end if end or not is_dev else now()
8391
self._apply = apply
@@ -105,9 +113,11 @@ def __init__(
105113
self._ensure_no_forward_only_revert()
106114
self._ensure_no_forward_only_new_models()
107115

108-
categorized_snapshots = self._categorize_snapshots()
109-
self.directly_modified = categorized_snapshots[0]
110-
self.indirectly_modified = categorized_snapshots[1]
116+
directly_indirectly_modified = self._build_directly_and_indirectly_modified()
117+
self.directly_modified = directly_indirectly_modified[0]
118+
self.indirectly_modified = directly_indirectly_modified[1]
119+
120+
self._categorize_snapshots()
111121

112122
self._categorized: t.Optional[t.List[Snapshot]] = None
113123
self._uncategorized: t.Optional[t.List[Snapshot]] = None
@@ -360,66 +370,26 @@ def _add_restatements(self, restate_models: t.Iterable[str]) -> None:
360370
)
361371
self._restatements.update(downstream)
362372

363-
def _categorize_snapshots(self) -> t.Tuple[t.List[Snapshot], SnapshotMapping]:
364-
"""Automatically categorizes snapshots that can be automatically categorized and
365-
returns a list of added and directly modified snapshots as well as the mapping of
366-
indirectly modified snapshots.
373+
def _build_directly_and_indirectly_modified(self) -> t.Tuple[t.List[Snapshot], SnapshotMapping]:
374+
"""Builds collections of directly and indirectly modified snapshots.
367375
368376
Returns:
369377
The tuple in which the first element contains a list of added and directly modified
370378
snapshots while the second element contains a mapping of indirectly modified snapshots.
371379
"""
372-
queue = deque(self._dag.sorted())
373380
directly_modified = []
374381
all_indirectly_modified = set()
375382

376-
while queue:
377-
model_name = queue.popleft()
378-
379-
if model_name not in self.context_diff.snapshots:
380-
continue
381-
382-
upstream_model_names = self._dag.upstream(model_name)
383-
384-
if not self.forward_only:
385-
self._ensure_no_paused_forward_only_upstream(model_name, upstream_model_names)
386-
387-
snapshot = self.context_diff.snapshots[model_name]
388-
383+
for model_name, snapshot in self.context_diff.snapshots.items():
389384
if model_name in self.context_diff.modified_snapshots:
390-
if self.forward_only and self.is_new_snapshot(snapshot):
391-
# In case of the forward only plan any modifications result in reuse of the
392-
# previous version for non-seed models.
393-
# New snapshots of seed models are considered non-breaking ones.
394-
if not snapshot.is_seed_kind:
395-
snapshot.set_version(snapshot.previous_version)
396-
snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY
397-
else:
398-
snapshot.set_version()
399-
snapshot.change_category = SnapshotChangeCategory.NON_BREAKING
400-
401385
if self.context_diff.directly_modified(model_name):
402386
directly_modified.append(snapshot)
403387
else:
404388
all_indirectly_modified.add(model_name)
405-
406-
# set to breaking if an indirect child has no directly modified parents
407-
# that need a decision. this can happen when a revert to a parent causes
408-
# an indirectly modified snapshot to be created because of a new parent
409-
if not snapshot.version and not any(
410-
self.context_diff.directly_modified(upstream)
411-
and not self.context_diff.snapshots[upstream].version
412-
for upstream in upstream_model_names
413-
):
414-
snapshot.set_version()
415-
416389
elif model_name in self.context_diff.added:
417-
if self.is_new_snapshot(snapshot):
418-
snapshot.set_version()
419390
directly_modified.append(snapshot)
420391

421392
indirectly_modified: SnapshotMapping = defaultdict(set)
422-
423393
for snapshot in directly_modified:
424394
for downstream in self._dag.downstream(snapshot.name):
425395
if downstream in all_indirectly_modified:
@@ -430,6 +400,56 @@ def _categorize_snapshots(self) -> t.Tuple[t.List[Snapshot], SnapshotMapping]:
430400
indirectly_modified,
431401
)
432402

403+
def _categorize_snapshots(self) -> None:
404+
"""Automatically categorizes snapshots that can be automatically categorized and
405+
sets their versions and change categories in place; the method itself
406+
returns nothing.
407+
"""
408+
for model_name, snapshot in self.context_diff.snapshots.items():
409+
upstream_model_names = self._dag.upstream(model_name)
410+
411+
if not self.forward_only:
412+
self._ensure_no_paused_forward_only_upstream(model_name, upstream_model_names)
413+
414+
if model_name in self.context_diff.modified_snapshots:
415+
is_directly_modified = self.context_diff.directly_modified(model_name)
416+
417+
if self.is_new_snapshot(snapshot):
418+
if self.forward_only:
419+
# In case of the forward only plan any modifications result in reuse of the
420+
# previous version for non-seed models.
421+
# New snapshots of seed models are considered non-breaking ones.
422+
if not snapshot.is_seed_kind:
423+
snapshot.set_version(snapshot.previous_version)
424+
snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY
425+
else:
426+
snapshot.set_version()
427+
snapshot.change_category = SnapshotChangeCategory.NON_BREAKING
428+
elif self.auto_categorization_enabled and is_directly_modified:
429+
new, old = self.context_diff.modified_snapshots[model_name]
430+
change_category = categorize_change(
431+
new, old, config=self.categorizer_config
432+
)
433+
if change_category is not None:
434+
self.set_choice(new, change_category)
435+
436+
# set to breaking if an indirect child has no directly modified parents
437+
# that need a decision. this can happen when a revert to a parent causes
438+
# an indirectly modified snapshot to be created because of a new parent
439+
if (
440+
not is_directly_modified
441+
and not snapshot.version
442+
and not any(
443+
self.context_diff.directly_modified(upstream)
444+
and not self.context_diff.snapshots[upstream].version
445+
for upstream in upstream_model_names
446+
)
447+
):
448+
snapshot.set_version()
449+
450+
elif model_name in self.context_diff.added and self.is_new_snapshot(snapshot):
451+
snapshot.set_version()
452+
433453
def _ensure_no_paused_forward_only_upstream(
434454
self, model_name: str, upstream_model_names: t.Iterable[str]
435455
) -> None:

tests/core/test_plan.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,25 @@ def test_start_inference(make_snapshot, mocker: MockerFixture):
355355
assert snapshot_b.version_get_or_generate() in plan._missing_intervals
356356

357357
assert plan.start == to_timestamp("2022-01-01")
358+
359+
360+
def test_auto_categorization(make_snapshot, mocker: MockerFixture):
361+
snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds")))
362+
snapshot.set_version()
363+
364+
updated_snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 2, ds")))
365+
366+
dag = DAG[str]({"a": set()})
367+
368+
context_diff_mock = mocker.Mock()
369+
context_diff_mock.snapshots = {"a": updated_snapshot}
370+
context_diff_mock.added = set()
371+
context_diff_mock.modified_snapshots = {"a": (updated_snapshot, snapshot)}
372+
context_diff_mock.new_snapshots = {updated_snapshot.snapshot_id: updated_snapshot}
373+
374+
state_reader_mock = mocker.Mock()
375+
376+
Plan(context_diff_mock, dag, state_reader_mock)
377+
378+
assert updated_snapshot.version == updated_snapshot.fingerprint.to_version()
379+
assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING

0 commit comments

Comments
 (0)