Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions sqlalchemy_kusto/dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def visit_select(
from_object = select_stmt.get_final_froms()[0]
if hasattr(from_object, "element"):
query = self._get_most_inner_element(from_object.element)
(main, lets) = self._extract_let_statements(query.text)
main, lets = self._extract_let_statements(query.text)
compiled_query_lines.extend(lets)
compiled_query_lines.append(
f"let {from_object.name} = ({self._convert_schema_in_statement(main)});"
Expand Down Expand Up @@ -362,29 +362,40 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str:
or KustoKqlCompiler._is_number_literal(name)
) and not is_alias:
return name
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]
# First, check if the name is already wrapped in ["ColumnName"] (escaped format)
if name.startswith('["') and name.endswith('"]'):
return name # Return as is if already properly escaped
# Remove surrounding spaces
# Handle mathematical operations (wrap only the column part before operators)
# Find the position of the first operator or space that separates the column name
# Handle arithmetic expressions by recursively escaping both sides
if not is_alias:
for operator in ["/", "+", "-", "*"]:
if operator in name:
# Split the name at the first operator and wrap the left part
parts = name.split(operator, 1)
# Remove quotes if they exist at the edges
col_part = parts[0].strip()
if col_part.startswith('"') and col_part.endswith('"'):
col_part = col_part[1:-1].strip()
col_part = col_part.replace('"', '\\"')
return f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part
# No operators found, just wrap the entire name
# Find operator that's not inside quotes
pos = KustoKqlCompiler._find_operator_outside_quotes(name, operator)
if pos != -1:
left_part = name[:pos].strip()
right_part = name[pos + 1 :].strip()
# Recursively escape both sides
left_escaped = KustoKqlCompiler._escape_and_quote_columns(left_part)
right_escaped = KustoKqlCompiler._escape_and_quote_columns(
right_part
)
return f"{left_escaped} {operator} {right_escaped}"
# No operators found - strip surrounding quotes if present, then wrap
if name.startswith('"') and name.endswith('"'):
name = name[1:-1]
name = name.replace('"', '\\"')
return f'["{name}"]'

@staticmethod
def _find_operator_outside_quotes(text: str, operator: str) -> int:
"""Find position of operator that's not inside quoted strings. Returns -1 if not found."""
in_quotes = False
for i, ch in enumerate(text):
if ch == '"' and (i == 0 or text[i - 1] != "\\"):
in_quotes = not in_quotes
elif ch == operator and not in_quotes:
return i
return -1

@staticmethod
def _sql_to_kql_where(where_clause: str) -> str:
where_clause = where_clause.strip().replace("\n", "")
Expand Down Expand Up @@ -553,7 +564,7 @@ def _is_kql_function(name: str) -> bool:

@staticmethod
def _is_number_literal(s: str) -> bool:
pattern = r"^[0-9]+$"
pattern = r"^\d+(\.\d+)?$"
return bool(re.match(pattern, s))

def _get_most_inner_element(self, clause):
Expand Down
149 changes: 137 additions & 12 deletions tests/unit/test_dialect_kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,10 @@ def test_group_by_text():
).replace("\n", "")
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], '
'["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'["ActiveUsersLastMonth"]'
'| summarize by ["EventInfo_Time"] / time(1d)'
'| extend ["ActiveUserMetric"] = ["ActiveUsers"], '
'["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["ActiveUserMetric"]'
'| order by ["ActiveUserMetric"] desc'
)
Expand Down Expand Up @@ -224,20 +225,19 @@ def test_group_by_text_vaccine_dataset():
query.compile(engine, compile_kwargs={"literal_binds": True})
).replace("\n", "")
query_expected = (
'database("superset").["CovidVaccineData"]| '
'extend ["country_name"] = ["country_name"]| '
'summarize by ["country_name"]| '
'project ["country_name"]| order by ["country_name"] asc'
'database("superset").["CovidVaccineData"]'
'| summarize by ["country_name"]'
'| extend ["country_name"] = ["country_name"]'
'| project ["country_name"]'
'| order by ["country_name"] asc'
)
assert query_compiled == query_expected


def test_is_kql_function():
assert KustoKqlCompiler._is_kql_function(
"""case(Size <= 3, "Small",
assert KustoKqlCompiler._is_kql_function("""case(Size <= 3, "Small",
Size <= 10, "Medium",
"Large")"""
)
"Large")""")
assert KustoKqlCompiler._is_kql_function("""bin(time(16d), 7d)""")
assert KustoKqlCompiler._is_kql_function(
"""iff((EventType in ("Heavy Rain", "Flash Flood", "Flood")), "Rain event", "Not rain event")"""
Expand Down Expand Up @@ -328,8 +328,8 @@ def test_distinct_count_by_text():
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["DistinctUsers"]'
'| order by ["ActiveUserMetric"] desc'
)
Expand All @@ -354,8 +354,8 @@ def test_distinct_count_alt_by_text():
# raw query text from query
query_expected = (
'["ActiveUsersLastMonth"]'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)'
'| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)'
'| project ["EventInfo_Time"], ["DistinctUsers"]'
'| order by ["ActiveUserMetric"] desc'
)
Expand Down Expand Up @@ -549,6 +549,131 @@ def test_match_aggregates(column_name: str, expected_aggregate: str):
assert kql_agg is None


def test_escape_and_quote_columns_with_two_quoted_measures():
    """Two double-quoted measures joined by ``+`` are each bracket-escaped.

    ``"Measure 1" + "Measure 2"`` is expected to become
    ``["Measure 1"] + ["Measure 2"]``.
    """
    expression = '"Measure 1" + "Measure 2"'
    expected = '["Measure 1"] + ["Measure 2"]'
    assert KustoKqlCompiler._escape_and_quote_columns(expression) == expected


def test_escape_and_quote_columns_preserves_already_bracketed():
    """A name already in ``["..."]`` form is returned unchanged."""
    bracketed = '["Measure 1"]'
    assert KustoKqlCompiler._escape_and_quote_columns(bracketed) == '["Measure 1"]'


def test_calculated_measure_with_two_adhoc_measures():
    """A measure defined as the sum of two quoted measures compiles to KQL.

    ``Measure 3 = "Measure 1" + "Measure 2"`` is expected to produce an
    ``extend`` of ``["Measure 1"] + ["Measure 2"]`` followed by a ``project``.
    """
    summed = literal_column('"Measure 1" + "Measure 2"').label("Measure 3")
    statement = select([summed]).select_from(text("SalesData"))
    compiled = statement.compile(engine, compile_kwargs={"literal_binds": True})
    # Newlines are stripped so the comparison is against a single-line query.
    actual = str(compiled).replace("\n", "")
    expected = "".join(
        (
            '["SalesData"]',
            '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]',
            '| project ["Measure 3"]',
        )
    )
    assert actual == expected


def test_escape_and_quote_columns_measure_with_constant():
    """A quoted measure multiplied by a constant escapes only the measure.

    ``"Measure 1" * 2`` is expected to become ``["Measure 1"] * 2``.
    """
    expression = '"Measure 1" * 2'
    expected = '["Measure 1"] * 2'
    assert KustoKqlCompiler._escape_and_quote_columns(expression) == expected


def test_escape_and_quote_columns_measure_with_operator_in_name():
    """An operator character inside a quoted name does not split the name.

    ``"Measure 1-2"`` is expected to become ``["Measure 1-2"]`` rather than
    being treated as a subtraction.
    """
    quoted_name = '"Measure 1-2"'
    expected = '["Measure 1-2"]'
    assert KustoKqlCompiler._escape_and_quote_columns(quoted_name) == expected


def test_is_number_literal():
    """Check which strings ``_is_number_literal`` accepts and rejects."""
    # Integers, and decimals with digits on both sides of the dot.
    accepted = ("5", "123", "0", "0.5", "5.0", "123.456")
    # Trailing/leading dot, negatives, scientific notation, then non-numeric text.
    rejected = (
        "5.",
        ".5",
        "-5",
        "-0.5",
        "1e10",
        "1.5e-3",
        "abc",
        "Measure 1",
        "",
    )
    for candidate in accepted:
        assert KustoKqlCompiler._is_number_literal(candidate) is True
    for candidate in rejected:
        assert KustoKqlCompiler._is_number_literal(candidate) is False


def test_calculated_measure_with_adhoc_measure_and_constant():
    """An aggregate measure and a measure derived from it compile together.

    With ``Measure 1 = count(*)`` and ``Measure 2 = "Measure 1" * 2``, the
    expected query summarizes ``Measure 1`` and then extends ``Measure 2`` as
    ``["Measure 1"] * 2``.
    """
    counted = literal_column("count(*)").label("Measure 1")
    doubled = literal_column('"Measure 1" * 2').label("Measure 2")
    statement = select([counted, doubled]).select_from(text("SalesData"))
    compiled = statement.compile(engine, compile_kwargs={"literal_binds": True})
    actual = str(compiled).replace("\n", "")
    expected = "".join(
        (
            '["SalesData"]',
            # The trailing space after count() is part of the expected output.
            '| summarize ["Measure 1"] = count() ',
            '| extend ["Measure 2"] = ["Measure 1"] * 2',
            '| project ["Measure 1"], ["Measure 2"]',
        )
    )
    assert actual == expected


def test_calculated_measure_with_two_adhoc_measures_and_aggregates():
    """Two aggregate measures plus a measure summing them compile together.

    With ``Measure 1 = count(*)``, ``Measure 2 = count(*)`` and
    ``Measure 3 = "Measure 1" + "Measure 2"``, the expected query summarizes
    both counts and then extends ``Measure 3`` as
    ``["Measure 1"] + ["Measure 2"]``.
    """
    first_count = literal_column("count(*)").label("Measure 1")
    second_count = literal_column("count(*)").label("Measure 2")
    total = literal_column('"Measure 1" + "Measure 2"').label("Measure 3")
    statement = select([first_count, second_count, total]).select_from(
        text("SalesData")
    )
    compiled = statement.compile(engine, compile_kwargs={"literal_binds": True})
    actual = str(compiled).replace("\n", "")
    # Either ordering of the summarize columns is accepted; per the original
    # author's note the columns come from a set, so their order may vary.
    summarize_variants = (
        '| summarize ["Measure 1"] = count(), ["Measure 2"] = count() ',
        '| summarize ["Measure 2"] = count(), ["Measure 1"] = count() ',
    )
    accepted = tuple(
        '["SalesData"]'
        + summarize_clause
        + '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]'
        + '| project ["Measure 1"], ["Measure 2"], ["Measure 3"]'
        for summarize_clause in summarize_variants
    )
    assert actual in accepted


@pytest.mark.parametrize(
("query_table_name", "expected_table_name"),
[
Expand Down