Skip to content

Commit 04bf185

Browse files
authored
Fix jinja block start/end bug (#400)
* Fix jinja block start/end bug * Add some checks for additional robustness * Fix annotate_types, update tests to reflect optimizer changes * Remove tokenizer overrides in dialect * Add pushdown_projections to the optimization pipeline
1 parent 3763dec commit 04bf185

5 files changed

Lines changed: 52 additions & 17 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"requests",
3939
"rich",
4040
"ruamel.yaml",
41-
"sqlglot==11.1.2",
41+
"sqlglot>=11.2.0",
4242
],
4343
extras_require={
4444
"dev": [

sqlmesh/core/dialect.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,14 @@ def parse(sql: str, default_dialect: str | None = None) -> t.List[exp.Expression
408408
if i < total - 1:
409409
chunks.append(([], False))
410410
else:
411-
if token.token_type == TokenType.BLOCK_START or (
412-
token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text)
411+
if (
412+
token.token_type == TokenType.BLOCK_START
413+
or (
414+
i < total - 1
415+
and token.token_type == TokenType.L_BRACE
416+
and tokens[i + 1].token_type == TokenType.L_BRACE
417+
)
418+
or (token.token_type == TokenType.STRING and JINJA_PATTERN.search(token.text))
413419
):
414420
chunks[-1] = (chunks[-1][0], True)
415421
chunks[-1][0].append(token)

sqlmesh/core/renderer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlglot.optimizer import optimize
1111
from sqlglot.optimizer.annotate_types import annotate_types
1212
from sqlglot.optimizer.expand_laterals import expand_laterals
13+
from sqlglot.optimizer.pushdown_projections import pushdown_projections
1314
from sqlglot.optimizer.qualify_columns import qualify_columns
1415
from sqlglot.optimizer.qualify_tables import qualify_tables
1516
from sqlglot.optimizer.simplify import simplify
@@ -30,6 +31,8 @@
3031
qualify_tables,
3132
qualify_columns,
3233
expand_laterals,
34+
pushdown_projections,
35+
annotate_types,
3336
)
3437

3538

@@ -151,8 +154,6 @@ def render(
151154
except SqlglotError as ex:
152155
raise_config_error(f"Invalid model query. {ex}", self._path)
153156

154-
self._query_cache[cache_key] = annotate_types(self._query_cache[cache_key])
155-
156157
query = self._query_cache[cache_key]
157158

158159
if expand:

sqlmesh/utils/jinja.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,34 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
2727
A dictionary of macro name to macro definition.
2828
"""
2929
self.reset()
30-
self.sql = jinja or ""
30+
self.sql = jinja
3131
self._tokens = Dialect.get_or_raise(dialect)().tokenizer.tokenize(jinja)
3232
self._index = -1
3333
self._advance()
3434

3535
macros: t.Dict[str, MacroInfo] = {}
3636

3737
while self._curr:
38-
if self._curr.token_type == TokenType.BLOCK_START:
38+
if self._at_block_start():
39+
if self._prev and self._prev.token_type == TokenType.L_BRACE:
40+
self._advance()
3941
macro_start = self._curr
4042
elif self._tag == "MACRO" and self._next:
4143
name = self._next.text
42-
while self._curr and self._curr.token_type != TokenType.BLOCK_END:
44+
while self._curr and not self._at_block_end():
4345
self._advance()
46+
else:
47+
if self._prev and self._prev.token_type == TokenType.R_BRACE:
48+
self._advance()
49+
4450
body_start = self._next
4551

4652
while self._curr and self._tag != "ENDMACRO":
47-
if self._curr.token_type == TokenType.BLOCK_START:
53+
if self._at_block_start():
4854
body_end = self._prev
55+
if self._prev and self._prev.token_type == TokenType.L_BRACE:
56+
self._advance()
57+
4958
self._advance()
5059

5160
calls = capture_jinja(self._find_sql(body_start, body_end)).calls
@@ -55,11 +64,30 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]:
5564

5665
return macros
5766

67+
def _at_block_start(self) -> bool:
68+
return self._curr.token_type == TokenType.BLOCK_START or self._match_pair(
69+
TokenType.L_BRACE, TokenType.L_BRACE, advance=False
70+
)
71+
72+
def _at_block_end(self) -> bool:
73+
return self._curr.token_type == TokenType.BLOCK_END or self._match_pair(
74+
TokenType.R_BRACE, TokenType.R_BRACE, advance=False
75+
)
76+
5877
def _advance(self, times: int = 1) -> None:
5978
super()._advance(times)
6079
self._tag = (
6180
self._curr.text.upper()
62-
if self._curr and self._prev and self._prev.token_type == TokenType.BLOCK_START
81+
if self._curr
82+
and self._prev
83+
and (
84+
self._prev.token_type == TokenType.BLOCK_START
85+
or (
86+
self._index > 1
87+
and self._tokens[self._index - 1].token_type == TokenType.L_BRACE
88+
and self._tokens[self._index - 2].token_type == TokenType.L_BRACE
89+
)
90+
)
6391
else ""
6492
)
6593

tests/core/test_audit.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_no_query():
140140

141141

142142
def test_macro(model: Model):
143-
expected_query = "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL"
143+
expected_query = "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL"
144144

145145
audit = Audit(
146146
name="test_audit",
@@ -163,7 +163,7 @@ def test_not_null_audit(model: Model):
163163
)
164164
assert (
165165
rendered_query_a.sql()
166-
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL"
166+
== "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL"
167167
)
168168

169169
rendered_query_a_and_b = builtin.not_null_audit.render_query(
@@ -172,23 +172,23 @@ def test_not_null_audit(model: Model):
172172
)
173173
assert (
174174
rendered_query_a_and_b.sql()
175-
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE a IS NULL OR b IS NULL"
175+
== "SELECT * FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE _q_0.a IS NULL OR _q_0.b IS NULL"
176176
)
177177

178178

179179
def test_unique_values_audit(model: Model):
180180
rendered_query_a = builtin.unique_values_audit.render_query(model, columns=[exp.to_column("a")])
181181
assert (
182182
rendered_query_a.sql()
183-
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1"
183+
== "SELECT _q_1.a_rank AS a_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1"
184184
)
185185

186186
rendered_query_a_and_b = builtin.unique_values_audit.render_query(
187187
model, columns=[exp.to_column("a"), exp.to_column("b")]
188188
)
189189
assert (
190190
rendered_query_a_and_b.sql()
191-
== "SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY b ORDER BY 1) AS b_rank FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01')) WHERE a_rank > 1 OR b_rank > 1"
191+
== "SELECT _q_1.a_rank AS a_rank, _q_1.b_rank AS b_rank FROM (SELECT ROW_NUMBER() OVER (PARTITION BY _q_0.a ORDER BY 1) AS a_rank, ROW_NUMBER() OVER (PARTITION BY _q_0.b ORDER BY 1) AS b_rank FROM (SELECT test_model.a AS a, test_model.b AS b FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0) AS _q_1 WHERE _q_1.a_rank > 1 OR _q_1.b_rank > 1"
192192
)
193193

194194

@@ -200,7 +200,7 @@ def test_accepted_values_audit(model: Model):
200200
)
201201
assert (
202202
rendered_query.sql()
203-
== "SELECT * FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') WHERE NOT a IN ('value_a', 'value_b')"
203+
== "SELECT * FROM (SELECT test_model.a AS a FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 WHERE NOT _q_0.a IN ('value_a', 'value_b')"
204204
)
205205

206206

@@ -211,5 +211,5 @@ def test_number_of_rows_audit(model: Model):
211211
)
212212
assert (
213213
rendered_query.sql()
214-
== "SELECT 1 FROM (SELECT * FROM db.test_model WHERE ds <= '1970-01-01' AND ds >= '1970-01-01') HAVING COUNT(*) <= 0 LIMIT 0 + 1"
214+
== """SELECT 1 AS "1" FROM (SELECT 1 AS _ FROM db.test_model AS test_model WHERE test_model.ds <= '1970-01-01' AND test_model.ds >= '1970-01-01') AS _q_0 HAVING COUNT(*) <= 0 LIMIT 0 + 1"""
215215
)

0 commit comments

Comments
 (0)