Skip to content

Commit d1e66d3

Browse files
authored
Dbt columns to sqlmesh (#391)
* Move dbt model creation from load_model to create_sql_model * Rename parse_model to parse * Add DBT project columns to SQLMesh models * Fix test * typo * Use dict comprehension
1 parent 75c6da2 commit d1e66d3

15 files changed

Lines changed: 225 additions & 69 deletions

File tree

sqlmesh/core/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from sqlmesh.core.config import Config, load_config_from_paths
5050
from sqlmesh.core.console import Console, get_console
5151
from sqlmesh.core.context_diff import ContextDiff
52-
from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse_model
52+
from sqlmesh.core.dialect import format_model_expressions, pandas_to_sql, parse
5353
from sqlmesh.core.engine_adapter import EngineAdapter
5454
from sqlmesh.core.environment import Environment
5555
from sqlmesh.core.hooks import hook
@@ -552,7 +552,7 @@ def format(self) -> None:
552552
if not model.is_sql:
553553
continue
554554
with open(model._path, "r+", encoding="utf-8") as file:
555-
expressions = parse_model(file.read(), default_dialect=self.dialect)
555+
expressions = parse(file.read(), default_dialect=self.dialect)
556556
file.seek(0)
557557
file.write(format_model_expressions(expressions, model.dialect))
558558
file.truncate()

sqlmesh/core/dialect.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,10 @@ def text_diff(
383383
JINJA_PATTERN = re.compile(r"{{|{%|{#")
384384

385385

386-
def parse_model(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
387-
"""Parse a sql string containing a model definition.
386+
def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression]:
387+
"""Parse a sql string.
388388
389+
Supports parsing model definition.
389390
If a jinja block is detected, the query is stored as raw string in a Jinja node.
390391
391392
Args:

sqlmesh/core/loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from sqlmesh.core import constants as c
1717
from sqlmesh.core.audit import Audit
18-
from sqlmesh.core.dialect import parse_model
18+
from sqlmesh.core.dialect import parse
1919
from sqlmesh.core.hooks import HookRegistry, hook
2020
from sqlmesh.core.macros import MacroRegistry, macro
2121
from sqlmesh.core.model import Model, SeedModel, load_model
@@ -190,7 +190,7 @@ def _load_sql_models(
190190
self._track_file(path)
191191
with open(path, "r", encoding="utf-8") as file:
192192
try:
193-
expressions = parse_model(file.read(), default_dialect=self._context.dialect)
193+
expressions = parse(file.read(), default_dialect=self._context.dialect)
194194
except SqlglotError as ex:
195195
raise ConfigError(f"Failed to parse a model definition at '{path}': {ex}")
196196
model = load_model(
@@ -240,7 +240,7 @@ def _load_audits(self) -> UniqueKeyDict[str, Audit]:
240240
for path in self._context.glob_path(self._context.audits_directory_path, ".sql"):
241241
self._track_file(path)
242242
with open(path, "r", encoding="utf-8") as file:
243-
expressions = parse_model(file.read(), default_dialect=self._context.dialect)
243+
expressions = parse(file.read(), default_dialect=self._context.dialect)
244244
audits = Audit.load_multiple(
245245
expressions=expressions,
246246
path=path,

sqlmesh/core/model/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlglot.expressions import split_num_words
88
from sqlglot.helper import seq_get
99

10-
from sqlmesh.core.dialect import parse_model
10+
from sqlmesh.core.dialect import parse
1111
from sqlmesh.utils.errors import ConfigError
1212

1313

@@ -37,7 +37,7 @@ def parse_expression(
3737
return [e for e in (maybe_parse(i) for i in v) if e]
3838

3939
if isinstance(v, str):
40-
return seq_get(parse_model(v), 0)
40+
return seq_get(parse(v), 0)
4141

4242
if not v:
4343
raise ConfigError(f"Could not parse {v}")

sqlmesh/core/model/definition.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,6 @@ def depends_on(self) -> t.Set[str]:
382382
self._depends_on = _find_tables(self.render_query()) - {self.name}
383383
return self._depends_on
384384

385-
@property
386-
def column_descriptions(self) -> t.Dict[str, str]:
387-
"""A dictionary of column names to annotation comments."""
388-
return {}
389-
390385
@property
391386
def columns_to_types(self) -> t.Dict[str, exp.DataType]:
392387
"""Returns the mapping of column names to types of this model."""
@@ -583,6 +578,9 @@ def columns_to_types(self) -> t.Dict[str, exp.DataType]:
583578

584579
@property
585580
def column_descriptions(self) -> t.Dict[str, str]:
581+
if self.column_descriptions_ is not None:
582+
return self.column_descriptions_
583+
586584
if self._column_descriptions is None:
587585
self._column_descriptions = {
588586
select.alias: "\n".join(comment.strip() for comment in select.comments)
@@ -847,7 +845,7 @@ def load_model(
847845
return create_sql_model(
848846
name,
849847
query,
850-
statements,
848+
statements=statements,
851849
defaults=defaults,
852850
path=path,
853851
module_path=module_path,
@@ -880,8 +878,8 @@ def load_model(
880878
def create_sql_model(
881879
name: str,
882880
query: exp.Expression,
883-
statements: t.List[exp.Expression],
884881
*,
882+
statements: t.Optional[t.List[exp.Expression]] = None,
885883
defaults: t.Optional[t.Dict[str, t.Any]] = None,
886884
path: Path = Path(),
887885
module_path: Path = Path(),
@@ -936,7 +934,7 @@ def create_sql_model(
936934
time_column_format=time_column_format,
937935
python_env=python_env,
938936
dialect=dialect,
939-
expressions=statements,
937+
expressions=statements or [],
940938
query=query,
941939
**kwargs,
942940
)

sqlmesh/core/model/meta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ModelMeta(PydanticModel):
5555
post: t.List[HookCall] = []
5656
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
5757
columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns")
58+
column_descriptions_: t.Optional[t.Dict[str, str]]
5859
audits: t.List[AuditReference] = []
5960

6061
_croniter: t.Optional[croniter] = None
@@ -193,6 +194,11 @@ def partitioned_by(self) -> t.List[str]:
193194
time_column = [self.time_column.column] if self.time_column else []
194195
return unique([*time_column, *self.partitioned_by_])
195196

197+
@property
198+
def column_descriptions(self) -> t.Dict[str, str]:
199+
"""A dictionary of column names to annotation comments."""
200+
return self.column_descriptions_ or {}
201+
196202
def interval_unit(self, sample_size: int = 10) -> IntervalUnit:
197203
"""Returns the IntervalUnit of the model
198204

sqlmesh/dbt/column.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
from pydantic import validator
6+
from sqlglot import exp, parse_one
67
from sqlglot.helper import ensure_list
78

89
from sqlmesh.dbt.common import GeneralConfig
@@ -21,6 +22,30 @@ def yaml_to_columns(
2122
return columns
2223

2324

25+
def column_types_to_sqlmesh(columns: t.Dict[str, ColumnConfig]) -> t.Dict[str, exp.DataType]:
26+
"""
27+
Get the sqlmesh column types
28+
29+
Returns:
30+
A dict of column name to exp.DataType
31+
"""
32+
return {
33+
name: parse_one(column.data_type, into=exp.DataType)
34+
for name, column in columns.items()
35+
if column.data_type
36+
}
37+
38+
39+
def column_descriptions_to_sqlmesh(columns: t.Dict[str, ColumnConfig]) -> t.Dict[str, str]:
40+
"""
41+
Get the sqlmesh column types
42+
43+
Returns:
44+
A dict of column name to description
45+
"""
46+
return {name: column.description for name, column in columns.items() if column.description}
47+
48+
2449
class ColumnConfig(GeneralConfig):
2550
"""
2651
Column configuration for a DBT project

sqlmesh/dbt/model.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,20 @@
99

1010
from sqlmesh.core import dialect as d
1111
from sqlmesh.core.config.base import UpdateStrategy
12-
from sqlmesh.core.model import Model, ModelKindName, load_model
13-
from sqlmesh.dbt.column import ColumnConfig, yaml_to_columns
12+
from sqlmesh.core.model import (
13+
IncrementalByTimeRangeKind,
14+
IncrementalByUniqueKeyKind,
15+
Model,
16+
ModelKind,
17+
ModelKindName,
18+
create_sql_model,
19+
)
20+
from sqlmesh.dbt.column import (
21+
ColumnConfig,
22+
column_descriptions_to_sqlmesh,
23+
column_types_to_sqlmesh,
24+
yaml_to_columns,
25+
)
1426
from sqlmesh.dbt.common import Dependencies, GeneralConfig
1527
from sqlmesh.dbt.macros import MacroConfig, ref_method, source_method, var_method
1628
from sqlmesh.dbt.seed import SeedConfig
@@ -108,6 +120,9 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
108120

109121
@validator("columns", pre=True)
110122
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
123+
if not isinstance(v, dict) or all(isinstance(col, ColumnConfig) for col in v.values()):
124+
return v
125+
111126
return yaml_to_columns(v)
112127

113128
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
@@ -132,22 +147,6 @@ def to_sqlmesh(
132147
macros: UniqueKeyDict[str, MacroConfig],
133148
) -> Model:
134149
"""Converts the dbt model into a SQLMesh model."""
135-
expressions = d.parse_model(
136-
f"""
137-
MODEL (
138-
name {self.model_name},
139-
kind {self.model_kind},
140-
);
141-
"""
142-
+ self.sql,
143-
default_dialect="",
144-
)
145-
146-
for jinja in expressions[1:]:
147-
# find all the refs here and filter the python env?
148-
if isinstance(jinja, d.Jinja):
149-
pass
150-
151150
source_mapping = {config.config_name: config.source_name for config in sources.values()}
152151
model_mapping = {name: config.model_name for name, config in models.items()}
153152
model_mapping.update({name: config.seed_name for name, config in seeds.items()})
@@ -166,9 +165,15 @@ def to_sqlmesh(
166165
for ref in dependencies.refs
167166
}
168167

169-
return load_model(
170-
expressions,
171-
path=self.path,
168+
expressions = d.parse(self.sql)
169+
170+
return create_sql_model(
171+
self.model_name,
172+
expressions[-1],
173+
kind=self.model_kind,
174+
statements=expressions[0:-1],
175+
columns=column_types_to_sqlmesh(self.columns) or None,
176+
column_descriptions_=column_descriptions_to_sqlmesh(self.columns) or None,
172177
python_env=python_env,
173178
depends_on=depends_on,
174179
start=self.start,
@@ -186,7 +191,7 @@ def model_name(self) -> str:
186191
return ".".join(part for part in (schema, self.alias or self.table_name) if part)
187192

188193
@property
189-
def model_kind(self) -> str:
194+
def model_kind(self) -> ModelKind:
190195
"""
191196
Get the sqlmesh ModelKind
192197
@@ -195,20 +200,20 @@ def model_kind(self) -> str:
195200
"""
196201
materialization = self.materialized
197202
if materialization == Materialization.TABLE:
198-
return ModelKindName.FULL.value
203+
return ModelKind(name=ModelKindName.FULL)
199204
if materialization == Materialization.VIEW:
200-
return ModelKindName.VIEW.value
205+
return ModelKind(name=ModelKindName.VIEW)
201206
if materialization == Materialization.INCREMENTAL:
202207
if self.time_column:
203-
return f"{ModelKindName.INCREMENTAL_BY_TIME_RANGE.value} (TIME_COLUMN {self.time_column})"
208+
return IncrementalByTimeRangeKind(time_column=self.time_column)
204209
if self.unique_key:
205-
return f"{ModelKindName.INCREMENTAL_BY_UNIQUE_KEY.value} (UNIQUE_KEY ({','.join(self.unique_key)}))"
210+
return IncrementalByUniqueKeyKind(unique_key=self.unique_key)
206211
raise ConfigError(
207212
"SQLMesh ensures idempotent incremental loads and thus does not support append."
208213
" Add either an unique key (merge) or a time column (insert-overwrite)."
209214
)
210215
if materialization == Materialization.EPHEMERAL:
211-
return ModelKindName.EMBEDDED.value
216+
return ModelKind(name=ModelKindName.EMBEDDED)
212217
raise ConfigError(f"{materialization.value} materialization not supported.")
213218

214219
def _all_dependencies(

sqlmesh/dbt/seed.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
from sqlmesh.core.config.base import UpdateStrategy
1010
from sqlmesh.core.model import Model, SeedKind, create_seed_model
11-
from sqlmesh.dbt.column import ColumnConfig, yaml_to_columns
11+
from sqlmesh.dbt.column import (
12+
ColumnConfig,
13+
column_descriptions_to_sqlmesh,
14+
column_types_to_sqlmesh,
15+
yaml_to_columns,
16+
)
1217
from sqlmesh.dbt.common import GeneralConfig
1318
from sqlmesh.utils.conversions import ensure_bool
1419

@@ -66,6 +71,9 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
6671

6772
@validator("columns", pre=True)
6873
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
74+
if not isinstance(v, dict) or all(isinstance(col, ColumnConfig) for col in v.values()):
75+
return v
76+
6977
return yaml_to_columns(v)
7078

7179
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
@@ -82,7 +90,11 @@ def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
8290
def to_sqlmesh(self) -> Model:
8391
"""Converts the dbt seed into a SQLMesh model."""
8492
return create_seed_model(
85-
self.seed_name, SeedKind(path=self.path.absolute()), path=self.path
93+
self.seed_name,
94+
SeedKind(path=self.path.absolute()),
95+
path=self.path,
96+
columns=column_types_to_sqlmesh(self.columns) or None,
97+
column_descriptions_=column_descriptions_to_sqlmesh(self.columns) or None,
8698
)
8799

88100
@property

sqlmesh/magics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from sqlmesh.core.console import get_console
1111
from sqlmesh.core.context import Context
12-
from sqlmesh.core.dialect import format_model_expressions, parse_model
12+
from sqlmesh.core.dialect import format_model_expressions, parse
1313
from sqlmesh.core.model import load_model
1414
from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests
1515
from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError
@@ -69,7 +69,7 @@ def model(self, line: str, sql: t.Optional[str] = None) -> None:
6969

7070
if sql:
7171
loaded = load_model(
72-
parse_model(sql, default_dialect=self._context.dialect),
72+
parse(sql, default_dialect=self._context.dialect),
7373
macros=self._context._macros,
7474
path=model._path,
7575
dialect=self._context.dialect,

0 commit comments

Comments
 (0)