Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 119 additions & 70 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,60 @@
"variancep",
}
AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\[?\"?\'?\w+\"?\]?)\s*(,.+)*\)?\s*\)"
# Pre-compiled regex for aggregate function matching (performance optimization).
# Compiled once at module load to avoid recompiling on every call, which significantly
# improves performance for query-heavy workloads.
KQL_AGG_PATTERN = re.compile(
r"\b(" + "|".join(kql_aggregates) + r")\s*\(", re.IGNORECASE
)


class _ParseState:
"""Tracks parsing state while scanning through text."""

__slots__ = ("in_double_quote", "in_single_quote", "in_bracket", "paren_depth")

def __init__(self):
self.in_double_quote = False
self.in_single_quote = False
self.in_bracket = False
self.paren_depth = 0

def update(self, ch: str, prev_ch: str | None) -> None:
"""Update state based on current and previous character."""
# Handle quotes (only if not escaped and not in conflicting context)
if (
ch == '"'
and prev_ch != "\\"
and not self.in_single_quote
and not self.in_bracket
):
self.in_double_quote = not self.in_double_quote
elif (
ch == "'"
and prev_ch != "\\"
and not self.in_double_quote
and not self.in_bracket
):
self.in_single_quote = not self.in_single_quote
# Handle brackets and parens (only if not in quotes)
elif not self.in_double_quote and not self.in_single_quote:
if ch == "[":
self.in_bracket = True
elif ch == "]":
self.in_bracket = False
elif not self.in_bracket:
if ch == "(":
self.paren_depth += 1
elif ch == ")":
self.paren_depth -= 1

@property
def in_quotes_or_brackets(self) -> bool:
"""Check if currently inside quotes or brackets."""
return self.in_double_quote or self.in_single_quote or self.in_bracket


class UniversalSet:
def __contains__(self, item):
return True
Expand All @@ -92,76 +141,77 @@ class KustoKqlCompiler(compiler.SQLCompiler):
visit_sequence = None
sort_with_clause_parts = 2

@staticmethod
def _find_top_level_operator(text: str, operator: str) -> int:
"""Find position of operator at depth 0 (not inside quotes, brackets, or parens).

Args:
text: The string to search in
operator: The single-character operator to find (e.g., '+', '-', '*', '/')

Returns:
The position of the operator at depth 0, or -1 if not found.
Returns -1 when the operator only appears inside quotes, brackets, or nested parens.
"""
state = _ParseState()
for i, ch in enumerate(text):
if (
ch == operator
and state.paren_depth == 0
and not state.in_quotes_or_brackets
):
return i
state.update(ch, text[i - 1] if i > 0 else None)
return -1

@staticmethod
def _is_inside_quotes_or_brackets(text: str, pos: int) -> bool:
"""Check if a position in text is inside quotes or brackets."""
if pos >= len(text):
return False

in_double_quote = False
in_single_quote = False
in_bracket = False

state = _ParseState()
for i in range(pos):
ch = text[i]
prev_ch = text[i - 1] if i > 0 else None
if ch == '"' and prev_ch != "\\" and not in_single_quote and not in_bracket:
in_double_quote = not in_double_quote
elif (
ch == "'" and prev_ch != "\\" and not in_double_quote and not in_bracket
):
in_single_quote = not in_single_quote
elif not in_double_quote and not in_single_quote:
if ch == "[":
in_bracket = True
elif ch == "]":
in_bracket = False

return in_double_quote or in_single_quote or in_bracket
state.update(text[i], text[i - 1] if i > 0 else None)
return state.in_quotes_or_brackets

@staticmethod
def _find_matching_paren(text: str, start_pos: int) -> int:
"""Find the matching closing parenthesis for an opening paren at start_pos."""
if start_pos >= len(text) or text[start_pos] != "(":
return -1

in_double_quote = False
in_single_quote = False
in_bracket = False
paren_depth = 1
state = _ParseState()
state.paren_depth = 1 # Start with depth 1 since we're at opening paren

for i in range(start_pos + 1, len(text)):
ch = text[i]
prev_ch = text[i - 1] if i > 0 else None

if ch == '"' and prev_ch != "\\" and not in_single_quote and not in_bracket:
in_double_quote = not in_double_quote
elif (
ch == "'" and prev_ch != "\\" and not in_double_quote and not in_bracket
):
in_single_quote = not in_single_quote
elif not in_double_quote and not in_single_quote:
if ch == "[":
in_bracket = True
elif ch == "]":
in_bracket = False
elif not in_bracket:
if ch == "(":
paren_depth += 1
elif ch == ")":
paren_depth -= 1
if paren_depth == 0:
return i
state.update(ch, text[i - 1] if i > 0 else None)
if state.paren_depth == 0:
return i
return -1

@staticmethod
def _has_operators_outside_quotes(expr: str) -> bool:
"""Check if expression has arithmetic operators outside of quoted strings."""
for operator in "+-*/":
pos = KustoKqlCompiler._find_operator_outside_quotes(expr, operator)
if pos != -1:
return True
return False
"""Check if expression has arithmetic operators outside of quoted strings and brackets."""
return any(
KustoKqlCompiler._find_top_level_operator(expr, op) != -1 for op in "+-*/"
)

@staticmethod
def _count_outer_parens(text: str) -> tuple[int, str]:
"""Count and strip outer parentheses from text. Returns (count, stripped_text)."""
text = text.strip()
count = 0
while len(text) > 1 and text[0] == "(" and text[-1] == ")":
depth = 0
for ch in text[:-1]: # Scan all but last char
depth += (ch == "(") - (ch == ")")
if depth == 0:
return count, text # First '(' closed before end
count += 1
text = text[1:-1].strip()
return count, text

@staticmethod
def _extract_and_replace_aggregates(
Expand Down Expand Up @@ -543,37 +593,36 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str:
# First, check if the name is already wrapped in ["ColumnName"] (escaped format)
if name.startswith('["') and name.endswith('"]'):
return name # Return as is if already properly escaped
# Handle arithmetic expressions by recursively escaping both sides
# Handle arithmetic expressions by recursively processing operands
if not is_alias:
outer_paren_count, inner = KustoKqlCompiler._count_outer_parens(name)
for operator in ["/", "+", "-", "*"]:
# Find operator that's not inside quotes
pos = KustoKqlCompiler._find_operator_outside_quotes(name, operator)
pos = KustoKqlCompiler._find_top_level_operator(inner, operator)
if pos != -1:
left_part = name[:pos].strip()
right_part = name[pos + 1 :].strip()
# Recursively escape both sides
left_escaped = KustoKqlCompiler._escape_and_quote_columns(left_part)
right_escaped = KustoKqlCompiler._escape_and_quote_columns(
right_part
left = KustoKqlCompiler._escape_and_quote_columns(
inner[:pos].strip()
)
right = KustoKqlCompiler._escape_and_quote_columns(
inner[pos + 1 :].strip()
)
return f"{left_escaped} {operator} {right_escaped}"
return (
"(" * outer_paren_count
+ left
+ " "
+ operator
+ " "
+ right
+ ")" * outer_paren_count
)
# No operators - just process inner content (don't re-add unnecessary parens)
if outer_paren_count > 0:
return KustoKqlCompiler._escape_and_quote_columns(inner)
# No operators found - strip surrounding quotes if present, then wrap
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]
name = name.replace('"', '\\"')
return f'["{name}"]'

@staticmethod
def _find_operator_outside_quotes(text: str, operator: str) -> int:
"""Find position of operator that's not inside quoted strings. Returns -1 if not found."""
in_quotes = False
for i, ch in enumerate(text):
if ch == '"' and (i == 0 or text[i - 1] != "\\"):
in_quotes = not in_quotes
elif not in_quotes and ch == operator:
return i
return -1

@staticmethod
def _sql_to_kql_where(where_clause: str) -> str:
where_clause = where_clause.strip().replace("\n", "")
Expand Down
88 changes: 88 additions & 0 deletions tests/integration/test_dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,94 @@ def test_date_bin_ops(test_label, group_fn, temp_table_name, expected, compare_d
assert actual_result == expected_records


def test_parentheses_preserved_in_expression(temp_table_name):
table = Table(
temp_table_name,
metadata,
Column("Text", String),
)
measure = literal_column('(("Text"))').label("Measure 1")
query = (
session.query(measure)
.select_from(table)
.order_by(literal_column('"Measure 1"'))
)
query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "")

assert "((" in query_compiled
assert "))" in query_compiled
assert '["Text"]' in query_compiled

with kql_engine.connect() as connection:
result = connection.execute(text(query_compiled))
values = [row[0] for row in result.fetchall()]
expected_count = 9
assert len(values) == expected_count
assert set(values) == {"value_0", "value_1"}


def test_quoted_identifier_converted_to_kql(temp_table_name):
table = Table(
temp_table_name,
metadata,
Column("Text", String),
)
quoted_text = literal_column("Text").label("QuotedText")
query = session.query(quoted_text).select_from(table).order_by(text("QuotedText"))
query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "")

assert '["Text"]' in query_compiled

with kql_engine.connect() as connection:
result = connection.execute(text(query_compiled))
values = [row[0] for row in result.fetchall()]
expected_count = 9
assert len(values) == expected_count
assert set(values) == {"value_0", "value_1"}


def test_uppercase_functions_lowercased(temp_table_name):
table = Table(
temp_table_name,
metadata,
)
query = session.query(func.COUNT(text("Id")).label("tag_count")).select_from(table)
query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "")

assert "count(" in query_compiled
assert "COUNT(" not in query_compiled

with kql_engine.connect() as connection:
result = connection.execute(text(query_compiled))
assert {row[0] for row in result.fetchall()} == {9}


def test_extend_after_summarize_for_calculated_measures(temp_table_name):
table = Table(
temp_table_name,
metadata,
)
measure_1 = func.COUNT(text("Id")).label("Measure 1")
measure_2 = literal_column('"Measure 1" * 2').label("Measure 2")
query = session.query(measure_1, measure_2).select_from(table)
query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "")

summarize_index = query_compiled.find("| summarize")
extend_index = query_compiled.find("| extend")
assert summarize_index != -1
assert extend_index != -1
assert summarize_index < extend_index

with kql_engine.connect() as connection:
result = connection.execute(text(query_compiled))
row = result.fetchone()
assert row is not None
expected_count = 9
expected_double_count = 18
assert int(row[0]) == expected_count
assert int(row[1]) == expected_double_count


def get_kcsb():
return (
KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL)
Expand Down
Loading