diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index 0cc0bd0..c10926a 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -107,7 +107,7 @@ def visit_select( from_object = select_stmt.get_final_froms()[0] if hasattr(from_object, "element"): query = self._get_most_inner_element(from_object.element) - (main, lets) = self._extract_let_statements(query.text) + main, lets = self._extract_let_statements(query.text) compiled_query_lines.extend(lets) compiled_query_lines.append( f"let {from_object.name} = ({self._convert_schema_in_statement(main)});" @@ -362,29 +362,40 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: or KustoKqlCompiler._is_number_literal(name) ) and not is_alias: return name - if name.startswith('"') and name.endswith('"'): - name = name[1:-1] # First, check if the name is already wrapped in ["ColumnName"] (escaped format) if name.startswith('["') and name.endswith('"]'): return name # Return as is if already properly escaped - # Remove surrounding spaces - # Handle mathematical operations (wrap only the column part before operators) - # Find the position of the first operator or space that separates the column name + # Handle arithmetic expressions by recursively escaping both sides if not is_alias: for operator in ["/", "+", "-", "*"]: - if operator in name: - # Split the name at the first operator and wrap the left part - parts = name.split(operator, 1) - # Remove quotes if they exist at the edges - col_part = parts[0].strip() - if col_part.startswith('"') and col_part.endswith('"'): - col_part = col_part[1:-1].strip() - col_part = col_part.replace('"', '\\"') - return f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part - # No operators found, just wrap the entire name + # Find operator that's not inside quotes + pos = KustoKqlCompiler._find_operator_outside_quotes(name, operator) + if pos != -1: + left_part = name[:pos].strip() + right_part = name[pos + 1 :].strip() + # Recursively escape both sides + left_escaped = KustoKqlCompiler._escape_and_quote_columns(left_part) + right_escaped = KustoKqlCompiler._escape_and_quote_columns( + right_part + ) + return f"{left_escaped} {operator} {right_escaped}" + # No operators found - strip surrounding quotes if present, then wrap + if name.startswith('"') and name.endswith('"'): + name = name[1:-1] name = name.replace('"', '\\"') return f'["{name}"]' + @staticmethod + def _find_operator_outside_quotes(text: str, operator: str) -> int: + """Find position of operator that's not inside quoted strings. Returns -1 if not found.""" + in_quotes = False + for i, ch in enumerate(text): + if ch == '"' and (i == 0 or text[i - 1] != "\\"): + in_quotes = not in_quotes + elif ch == operator and not in_quotes: + return i + return -1 + @staticmethod def _sql_to_kql_where(where_clause: str) -> str: where_clause = where_clause.strip().replace("\n", "") @@ -553,7 +564,7 @@ def _is_kql_function(name: str) -> bool: @staticmethod def _is_number_literal(s: str) -> bool: - pattern = r"^[0-9]+$" + pattern = r"^\d+(\.\d+)?$" return bool(re.match(pattern, s)) def _get_most_inner_element(self, clause): diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 04b8fc0..d3e4917 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -175,9 +175,10 @@ def test_group_by_text(): ).replace("\n", "") # raw query text from query query_expected = ( - '["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' - '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '["ActiveUsersLastMonth"]' '| summarize by ["EventInfo_Time"] / time(1d)' + '| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' + '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["ActiveUserMetric"]' '| order by ["ActiveUserMetric"] desc' ) @@ -224,20 +225,19 @@ def test_group_by_text_vaccine_dataset(): query.compile(engine, compile_kwargs={"literal_binds": True}) ).replace("\n", "") query_expected = ( - 'database("superset").["CovidVaccineData"]| ' - 'extend ["country_name"] = ["country_name"]| ' - 'summarize by ["country_name"]| ' - 'project ["country_name"]| order by ["country_name"] asc' + 'database("superset").["CovidVaccineData"]' + '| summarize by ["country_name"]' + '| extend ["country_name"] = ["country_name"]' + '| project ["country_name"]' + '| order by ["country_name"] asc' ) assert query_compiled == query_expected def test_is_kql_function(): - assert KustoKqlCompiler._is_kql_function( - """case(Size <= 3, "Small", + assert KustoKqlCompiler._is_kql_function("""case(Size <= 3, "Small", Size <= 10, "Medium", - "Large")""" - ) + "Large")""") assert KustoKqlCompiler._is_kql_function("""bin(time(16d), 7d)""") assert KustoKqlCompiler._is_kql_function( """iff((EventType in ("Heavy Rain", "Flash Flood", "Flood")), "Rain event", "Not rain event")""" @@ -328,8 +328,8 @@ def test_distinct_count_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' '| order by ["ActiveUserMetric"] desc' ) @@ -354,8 +354,8 @@ def test_distinct_count_alt_by_text(): # raw query text from query query_expected = ( '["ActiveUsersLastMonth"]' - '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' '| project ["EventInfo_Time"], ["DistinctUsers"]' '| order by ["ActiveUserMetric"] desc' ) @@ -549,6 +549,131 @@ def test_match_aggregates(column_name: str, expected_aggregate: str): assert kql_agg is None +def test_escape_and_quote_columns_with_two_quoted_measures(): + """Test that two quoted measure names with operator are properly escaped. + + e.g. "Measure 1" + "Measure 2" --> ["Measure 1"] + ["Measure 2"] + """ + result = KustoKqlCompiler._escape_and_quote_columns('"Measure 1" + "Measure 2"') + assert result == '["Measure 1"] + ["Measure 2"]' + + +def test_escape_and_quote_columns_preserves_already_bracketed(): + """Test that already-bracketed columns are not double-converted.""" + result = KustoKqlCompiler._escape_and_quote_columns('["Measure 1"]') + assert result == '["Measure 1"]' + + +def test_calculated_measure_with_two_adhoc_measures(): + """Test calculated measure referencing two ad hoc measures. + + Measure 3 = "Measure 1" + "Measure 2" should compile to ["Measure 1"] + ["Measure 2"] + """ + measure_3 = literal_column('"Measure 1" + "Measure 2"').label("Measure 3") + query = select([measure_3]).select_from(text("SalesData")) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + query_expected = ( + '["SalesData"]' + '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]' + '| project ["Measure 3"]' + ) + assert query_compiled == query_expected + + +def test_escape_and_quote_columns_measure_with_constant(): + """Test that measure with operator and constant is properly escaped. + + e.g. "Measure 1" * 2 --> ["Measure 1"] * 2 + """ + result = KustoKqlCompiler._escape_and_quote_columns('"Measure 1" * 2') + assert result == '["Measure 1"] * 2' + + +def test_escape_and_quote_columns_measure_with_operator_in_name(): + """Test that measure names containing operators are properly escaped. + + e.g. "Measure 1-2" --> ["Measure 1-2"] (not split as ["Measure 1"] - ["2"]) + """ + result = KustoKqlCompiler._escape_and_quote_columns('"Measure 1-2"') + assert result == '["Measure 1-2"]' + + +def test_is_number_literal(): + """Test _is_number_literal correctly identifies numeric literals.""" + # Should match: integers and decimals with digits on both sides of decimal + assert KustoKqlCompiler._is_number_literal("5") is True + assert KustoKqlCompiler._is_number_literal("123") is True + assert KustoKqlCompiler._is_number_literal("0") is True + assert KustoKqlCompiler._is_number_literal("0.5") is True + assert KustoKqlCompiler._is_number_literal("5.0") is True + assert KustoKqlCompiler._is_number_literal("123.456") is True + + # Should NOT match: trailing decimal, leading decimal, scientific notation, negatives + assert KustoKqlCompiler._is_number_literal("5.") is False + assert KustoKqlCompiler._is_number_literal(".5") is False + assert KustoKqlCompiler._is_number_literal("-5") is False + assert KustoKqlCompiler._is_number_literal("-0.5") is False + assert KustoKqlCompiler._is_number_literal("1e10") is False + assert KustoKqlCompiler._is_number_literal("1.5e-3") is False + + # Should NOT match: non-numeric strings + assert KustoKqlCompiler._is_number_literal("abc") is False + assert KustoKqlCompiler._is_number_literal("Measure 1") is False + assert KustoKqlCompiler._is_number_literal("") is False + + +def test_calculated_measure_with_adhoc_measure_and_constant(): + """Test calculated measure with an ad hoc measure and a constant. + + Measure 1 = count(*), Measure 2 = "Measure 1" * 2 + Measure 2 should compile to ["Measure 1"] * 2 + """ + measure_1 = literal_column("count(*)").label("Measure 1") + measure_2 = literal_column('"Measure 1" * 2').label("Measure 2") + query = select([measure_1, measure_2]).select_from(text("SalesData")) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + query_expected = ( + '["SalesData"]' + '| summarize ["Measure 1"] = count() ' + '| extend ["Measure 2"] = ["Measure 1"] * 2' + '| project ["Measure 1"], ["Measure 2"]' + ) + assert query_compiled == query_expected + + +def test_calculated_measure_with_two_adhoc_measures_and_aggregates(): + """Test calculated measure referencing two ad hoc measures with aggregates. + + Measure 1 = count(*), Measure 2 = count(*) + Measure 3 = "Measure 1" + "Measure 2" should compile to ["Measure 1"] + ["Measure 2"] + """ + measure_1 = literal_column("count(*)").label("Measure 1") + measure_2 = literal_column("count(*)").label("Measure 2") + measure_3 = literal_column('"Measure 1" + "Measure 2"').label("Measure 3") + query = select([measure_1, measure_2, measure_3]).select_from(text("SalesData")) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + # Summarize columns come from a set so order may vary + query_expected_1 = ( + '["SalesData"]' + '| summarize ["Measure 1"] = count(), ["Measure 2"] = count() ' + '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]' + '| project ["Measure 1"], ["Measure 2"], ["Measure 3"]' + ) + query_expected_2 = ( + '["SalesData"]' + '| summarize ["Measure 2"] = count(), ["Measure 1"] = count() ' + '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]' + '| project ["Measure 1"], ["Measure 2"], ["Measure 3"]' + ) + assert query_compiled in (query_expected_1, query_expected_2) + + @pytest.mark.parametrize( ("query_table_name", "expected_table_name"), [