From d6cc4e6427c068c3bec5425424f20e012b5ed023 Mon Sep 17 00:00:00 2001 From: Alison Gim Date: Tue, 10 Feb 2026 05:05:41 +0000 Subject: [PATCH 1/3] Add parentheses support for calculated measures - Add _ParseState class for unified parsing state tracking - Add _find_top_level_operator to find operators at depth 0 (outside quotes, brackets, and parens) - Add _count_outer_parens to handle nested parentheses in expressions - Add _wrap_column_refs_in_parens for arithmetic precedence - Update _escape_and_quote_columns to handle parenthesized expressions - Update _has_operators_outside_quotes to use new top-level operator detection - Refactor _is_inside_quotes_or_brackets and _find_matching_paren to use _ParseState - Remove old _find_operator_outside_quotes (replaced by _find_top_level_operator) - Update tests to reflect correct behavior for count(*) - no longer seen as expression with operators --- sqlalchemy_kusto/dialect_kql.py | 202 +++++++++++++++++++++----------- tests/unit/test_dialect_kql.py | 19 ++- 2 files changed, 140 insertions(+), 81 deletions(-) diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index ea845dc..a8e7652 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -66,11 +66,60 @@ "variancep", } AGGREGATE_PATTERN = r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\[?\"?\'?\w+\"?\]?)\s*(,.+)*\)?\s*\)" +# Pre-compiled regex for aggregate function matching (performance optimization). +# Compiled once at module load to avoid recompiling on every call, which significantly +# improves performance for query-heavy workloads. KQL_AGG_PATTERN = re.compile( r"\b(" + "|".join(kql_aggregates) + r")\s*\(", re.IGNORECASE ) +class _ParseState: + """Tracks parsing state while scanning through text.""" + + __slots__ = ("in_double_quote", "in_single_quote", "in_bracket", "paren_depth") + + def __init__(self): + self.in_double_quote = False + self.in_single_quote = False + self.in_bracket = False + self.paren_depth = 0 + + def update(self, ch: str, prev_ch: str | None) -> None: + """Update state based on current and previous character.""" + # Handle quotes (only if not escaped and not in conflicting context) + if ( + ch == '"' + and prev_ch != "\\" + and not self.in_single_quote + and not self.in_bracket + ): + self.in_double_quote = not self.in_double_quote + elif ( + ch == "'" + and prev_ch != "\\" + and not self.in_double_quote + and not self.in_bracket + ): + self.in_single_quote = not self.in_single_quote + # Handle brackets and parens (only if not in quotes) + elif not self.in_double_quote and not self.in_single_quote: + if ch == "[": + self.in_bracket = True + elif ch == "]": + self.in_bracket = False + elif not self.in_bracket: + if ch == "(": + self.paren_depth += 1 + elif ch == ")": + self.paren_depth -= 1 + + @property + def in_quotes_or_brackets(self) -> bool: + """Check if currently inside quotes or brackets.""" + return self.in_double_quote or self.in_single_quote or self.in_bracket + + class UniversalSet: def __contains__(self, item): return True @@ -92,32 +141,39 @@ class KustoKqlCompiler(compiler.SQLCompiler): visit_sequence = None sort_with_clause_parts = 2 + @staticmethod + def _find_top_level_operator(text: str, operator: str) -> int: + """Find position of operator at depth 0 (not inside quotes, brackets, or parens). + + Args: + text: The string to search in + operator: The single-character operator to find (e.g., '+', '-', '*', '/') + + Returns: + The position of the operator at depth 0, or -1 if not found. + Returns -1 when the operator only appears inside quotes, brackets, or nested parens. + """ + state = _ParseState() + for i, ch in enumerate(text): + if ( + ch == operator + and state.paren_depth == 0 + and not state.in_quotes_or_brackets + ): + return i + state.update(ch, text[i - 1] if i > 0 else None) + return -1 + @staticmethod def _is_inside_quotes_or_brackets(text: str, pos: int) -> bool: """Check if a position in text is inside quotes or brackets.""" if pos >= len(text): return False - in_double_quote = False - in_single_quote = False - in_bracket = False - + state = _ParseState() for i in range(pos): - ch = text[i] - prev_ch = text[i - 1] if i > 0 else None - if ch == '"' and prev_ch != "\\" and not in_single_quote and not in_bracket: - in_double_quote = not in_double_quote - elif ( - ch == "'" and prev_ch != "\\" and not in_double_quote and not in_bracket - ): - in_single_quote = not in_single_quote - elif not in_double_quote and not in_single_quote: - if ch == "[": - in_bracket = True - elif ch == "]": - in_bracket = False - - return in_double_quote or in_single_quote or in_bracket + state.update(text[i], text[i - 1] if i > 0 else None) + return state.in_quotes_or_brackets @staticmethod def _find_matching_paren(text: str, start_pos: int) -> int: @@ -125,43 +181,48 @@ def _find_matching_paren(text: str, start_pos: int) -> int: if start_pos >= len(text) or text[start_pos] != "(": return -1 - in_double_quote = False - in_single_quote = False - in_bracket = False - paren_depth = 1 + state = _ParseState() + state.paren_depth = 1 # Start with depth 1 since we're at opening paren for i in range(start_pos + 1, len(text)): ch = text[i] - prev_ch = text[i - 1] if i > 0 else None - - if ch == '"' and prev_ch != "\\" and not in_single_quote and not in_bracket: - in_double_quote = not in_double_quote - elif ( - ch == "'" and prev_ch != "\\" and not in_double_quote and not in_bracket - ): - in_single_quote = not in_single_quote - elif not in_double_quote and not in_single_quote: - if ch == "[": - in_bracket = True - elif ch == "]": - in_bracket = False - elif not in_bracket: - if ch == "(": - paren_depth += 1 - elif ch == ")": - paren_depth -= 1 - if paren_depth == 0: - return i + state.update(ch, text[i - 1] if i > 0 else None) + if state.paren_depth == 0: + return i return -1 @staticmethod def _has_operators_outside_quotes(expr: str) -> bool: - """Check if expression has arithmetic operators outside of quoted strings.""" - for operator in "+-*/": - pos = KustoKqlCompiler._find_operator_outside_quotes(expr, operator) - if pos != -1: - return True - return False + """Check if expression has arithmetic operators outside of quoted strings and brackets.""" + return any( + KustoKqlCompiler._find_top_level_operator(expr, op) != -1 for op in "+-*/" + ) + + @staticmethod + def _count_outer_parens(text: str) -> tuple[int, str]: + """Count and strip outer parentheses from text. Returns (count, stripped_text).""" + text = text.strip() + count = 0 + while len(text) >= 2 and text[0] == "(" and text[-1] == ")": # noqa: PLR2004 + depth = 0 + for ch in text[:-1]: # Scan all but last char + depth += (ch == "(") - (ch == ")") + if depth == 0: + return count, text # First '(' closed before end + count += 1 + text = text[1:-1].strip() + return count, text + + @staticmethod + def _wrap_column_refs_in_parens(expr: str) -> str: + """Wrap bracket-quoted column refs in parens for arithmetic precedence, unless already wrapped.""" + + def wrap_col_ref(m: re.Match[str]) -> str: + if m.start() > 0 and expr[m.start() - 1] == "(": + return m.group(1) + return f"({m.group(1)})" + + return re.sub(r'(\["(?:[^"\\]|\\.)*"\])', wrap_col_ref, expr) @staticmethod def _extract_and_replace_aggregates( @@ -540,40 +601,39 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: or KustoKqlCompiler._is_number_literal(name) ) and not is_alias: return name - # First, check if the name is already wrapped in ["ColumnName"] (escaped format) if name.startswith('["') and name.endswith('"]'): return name # Return as is if already properly escaped - # Handle arithmetic expressions by recursively escaping both sides + # Handle arithmetic expressions by recursively processing operands if not is_alias: + outer_paren_count, inner = KustoKqlCompiler._count_outer_parens(name) for operator in ["/", "+", "-", "*"]: - # Find operator that's not inside quotes - pos = KustoKqlCompiler._find_operator_outside_quotes(name, operator) + pos = KustoKqlCompiler._find_top_level_operator(inner, operator) if pos != -1: - left_part = name[:pos].strip() - right_part = name[pos + 1 :].strip() - # Recursively escape both sides - left_escaped = KustoKqlCompiler._escape_and_quote_columns(left_part) - right_escaped = KustoKqlCompiler._escape_and_quote_columns( - right_part + left = KustoKqlCompiler._escape_and_quote_columns( + inner[:pos].strip() + ) + right = KustoKqlCompiler._escape_and_quote_columns( + inner[pos + 1 :].strip() + ) + return ( + "(" * outer_paren_count + + left + + " " + + operator + + " " + + right + + ")" * outer_paren_count ) - return f"{left_escaped} {operator} {right_escaped}" + # No operators - recurse on inner content if we stripped parens + if outer_paren_count > 0: + inner_result = KustoKqlCompiler._escape_and_quote_columns(inner) + return "(" * outer_paren_count + inner_result + ")" * outer_paren_count # No operators found - strip surrounding quotes if present, then wrap if name.startswith('"') and name.endswith('"'): name = name[1:-1] name = name.replace('"', '\\"') return f'["{name}"]' - @staticmethod - def _find_operator_outside_quotes(text: str, operator: str) -> int: - """Find position of operator that's not inside quoted strings. Returns -1 if not found.""" - in_quotes = False - for i, ch in enumerate(text): - if ch == '"' and (i == 0 or text[i - 1] != "\\"): - in_quotes = not in_quotes - elif not in_quotes and ch == operator: - return i - return -1 - @staticmethod def _sql_to_kql_where(where_clause: str) -> str: where_clause = where_clause.strip().replace("\n", "") diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 89692d3..e0fec45 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -429,8 +429,7 @@ def test_select_count(): 'let inner_qry = (["logs"]);' "inner_qry" "| where Field1 > 1 and Field2 < 2" - '| summarize ["__total-count_1"] = count() ' - '| extend ["total-count"] = ["__total-count_1"]' + '| summarize ["total-count"] = count() ' '| project ["total-count"]' '| order by ["total-count"] desc' "| take 5" @@ -631,7 +630,7 @@ def test_calculated_measure_with_adhoc_measure_and_constant(): """Test calculated measure with an ad hoc measure and a constant. Measure 1 = count(*), Measure 2 = "Measure 1" * 2 - Measure 2 should compile to (["Measure 1"]) * 2 (parentheses for arithmetic precedence) + Measure 2 should compile to ["Measure 1"] * 2 (references the predefined measure) """ measure_1 = literal_column("count(*)").label("Measure 1") measure_2 = literal_column('"Measure 1" * 2').label("Measure 2") @@ -641,8 +640,8 @@ def test_calculated_measure_with_adhoc_measure_and_constant(): ).replace("\n", "") query_expected = ( '["SalesData"]' - '| summarize ["__Measure 1_1"] = count() ' - '| extend ["Measure 1"] = ["__Measure 1_1"], ["Measure 2"] = ["Measure 1"] * 2' + '| summarize ["Measure 1"] = count() ' + '| extend ["Measure 2"] = ["Measure 1"] * 2' '| project ["Measure 1"], ["Measure 2"]' ) assert query_compiled == query_expected @@ -652,7 +651,7 @@ def test_calculated_measure_with_two_adhoc_measures_and_aggregates(): """Test calculated measure referencing two ad hoc measures with aggregates. Measure 1 = count(*), Measure 2 = count(*) - Measure 3 = "Measure 1" + "Measure 2" should compile to (["Measure 1"]) + (["Measure 2"]) + Measure 3 = "Measure 1" + "Measure 2" should compile to ["Measure 1"] + ["Measure 2"] """ measure_1 = literal_column("count(*)").label("Measure 1") measure_2 = literal_column("count(*)").label("Measure 2") @@ -663,8 +662,8 @@ def test_calculated_measure_with_two_adhoc_measures_and_aggregates(): ).replace("\n", "") query_expected = ( '["SalesData"]' - '| summarize ["__Measure 1_1"] = count() ' - '| extend ["Measure 1"] = ["__Measure 1_1"], ["Measure 2"] = ["__Measure 1_1"], ["Measure 3"] = ["Measure 1"] + ["Measure 2"]' + '| summarize ["Measure 1"] = count(), ["Measure 2"] = count() ' + '| extend ["Measure 3"] = ["Measure 1"] + ["Measure 2"]' '| project ["Measure 1"], ["Measure 2"], ["Measure 3"]' ) assert query_compiled == query_expected @@ -709,8 +708,8 @@ def test_calculated_measure_with_mixed_aggregates_and_references(): ).replace("\n", "") query_expected = ( '["SalesData"]' - '| summarize ["__Predefined 1_1"] = count(), ["__Calculated_1"] = count(["b"]) ' - '| extend ["Predefined 1"] = ["__Predefined 1_1"], ["Calculated"] = ["Predefined 1"] + ["__Calculated_1"]' + '| summarize ["Predefined 1"] = count(), ["__Calculated_1"] = count(["b"]) ' + '| extend ["Calculated"] = ["Predefined 1"] + ["__Calculated_1"]' '| project ["Predefined 1"], ["Calculated"]' ) assert query_compiled == query_expected From 51e16e847087a9e91b7b76b31f7c17eedcbd55ff Mon Sep 17 00:00:00 2001 From: Alison Gim Date: Tue, 10 Feb 2026 23:34:29 +0000 Subject: [PATCH 2/3] Add named constants for test values to satisfy PLR2004 linting rule - Add position constants to TestFindTopLevelOperator class - Add paren count constants to TestCountOuterParens class - Add closing position constants to TestFindMatchingParen class - Add aggregate count constants to TestExtractAndReplaceAggregates class - Add integration tests for parentheses preservation, quoted identifiers, uppercase functions, and extend after summarize --- sqlalchemy_kusto/dialect_kql.py | 16 +- tests/integration/test_dialect_kql.py | 88 +++++++ tests/unit/test_dialect_kql.py | 329 ++++++++++++++++++++++++-- 3 files changed, 400 insertions(+), 33 deletions(-) diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index a8e7652..5043819 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -213,17 +213,6 @@ def _count_outer_parens(text: str) -> tuple[int, str]: text = text[1:-1].strip() return count, text - @staticmethod - def _wrap_column_refs_in_parens(expr: str) -> str: - """Wrap bracket-quoted column refs in parens for arithmetic precedence, unless already wrapped.""" - - def wrap_col_ref(m: re.Match[str]) -> str: - if m.start() > 0 and expr[m.start() - 1] == "(": - return m.group(1) - return f"({m.group(1)})" - - return re.sub(r'(\["(?:[^"\\]|\\.)*"\])', wrap_col_ref, expr) - @staticmethod def _extract_and_replace_aggregates( expr: str, measure_name: str, existing_aggs: dict[str, str] | None = None @@ -624,10 +613,9 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: + right + ")" * outer_paren_count ) - # No operators - recurse on inner content if we stripped parens + # No operators - just process inner content (don't re-add unnecessary parens) if outer_paren_count > 0: - inner_result = KustoKqlCompiler._escape_and_quote_columns(inner) - return "(" * outer_paren_count + inner_result + ")" * outer_paren_count + return KustoKqlCompiler._escape_and_quote_columns(inner) # No operators found - strip surrounding quotes if present, then wrap if name.startswith('"') and name.endswith('"'): name = name[1:-1] diff --git a/tests/integration/test_dialect_kql.py b/tests/integration/test_dialect_kql.py index 5410337..028d0da 100644 --- a/tests/integration/test_dialect_kql.py +++ b/tests/integration/test_dialect_kql.py @@ -202,6 +202,94 @@ def test_date_bin_ops(test_label, group_fn, temp_table_name, expected, compare_d assert actual_result == expected_records +def test_parentheses_preserved_in_expression(temp_table_name): + table = Table( + temp_table_name, + metadata, + Column("Text", String), + ) + measure = literal_column('(("Text"))').label("Measure 1") + query = ( + session.query(measure) + .select_from(table) + .order_by(literal_column('"Measure 1"')) + ) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + + assert "((" in query_compiled + assert "))" in query_compiled + assert '["Text"]' in query_compiled + + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + values = [row[0] for row in result.fetchall()] + expected_count = 9 + assert len(values) == expected_count + assert set(values) == {"value_0", "value_1"} + + +def test_quoted_identifier_converted_to_kql(temp_table_name): + table = Table( + temp_table_name, + metadata, + Column("Text", String), + ) + quoted_text = literal_column("Text").label("QuotedText") + query = session.query(quoted_text).select_from(table).order_by(text("QuotedText")) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + + assert '["Text"]' in query_compiled + + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + values = [row[0] for row in result.fetchall()] + expected_count = 9 + assert len(values) == expected_count + assert set(values) == {"value_0", "value_1"} + + +def test_uppercase_functions_lowercased(temp_table_name): + table = Table( + temp_table_name, + metadata, + ) + query = session.query(func.COUNT(text("Id")).label("tag_count")).select_from(table) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + + assert "count(" in query_compiled + assert "COUNT(" not in query_compiled + + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + assert {row[0] for row in result.fetchall()} == {9} + + +def test_extend_after_summarize_for_calculated_measures(temp_table_name): + table = Table( + temp_table_name, + metadata, + ) + measure_1 = func.COUNT(text("Id")).label("Measure 1") + measure_2 = literal_column('"Measure 1" * 2').label("Measure 2") + query = session.query(measure_1, measure_2).select_from(table) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + + summarize_index = query_compiled.find("| summarize") + extend_index = query_compiled.find("| extend") + assert summarize_index != -1 + assert extend_index != -1 + assert summarize_index < extend_index + + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + row = result.fetchone() + assert row is not None + expected_count = 9 + expected_double_count = 18 + assert int(row[0]) == expected_count + assert int(row[1]) == expected_double_count + + def get_kcsb(): return ( KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL) diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index e0fec45..ff92bf6 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -565,6 +565,127 @@ def test_escape_and_quote_columns_preserves_already_bracketed(): assert result == '["Measure 1"]' +class TestCalculatedMeasuresWithParentheses: + """Tests for calculated measures with parentheses support.""" + + @pytest.fixture + def pt_search_table(self): + """Table matching the Superset PT_Search_scenario use case.""" + metadata = MetaData() + return Table( + "PT_Search_scenario", + metadata, + Column("UserInfo_Ring", String), + Column("UserInfo_Region", String), + schema="test_schema", + ) + + def test_calculated_measure_single_paren(self, pt_search_table): + """Test a calculated measure with single parentheses wrapper.""" + measure_16 = literal_column('("UserInfo_Ring Count")').label("Measure 16") + + query = select(measure_16).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + assert '["Measure 16"]' in compiled + + def test_calculated_measure_double_paren(self, pt_search_table): + """Test a calculated measure with double parentheses wrapper. + + Double parens around a single measure reference are stripped since + they're unnecessary for precedence. + """ + measure_3 = literal_column('(("Measure 1"))').label("Measure 3") + + query = select(measure_3).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + assert '["Measure 3"]' in compiled + # Parens should be stripped for single values (no operators inside) + assert '["Measure 1"]' in compiled + + def test_calculated_measure_parens_addition(self, pt_search_table): + """Test a calculated measure with parenthesized addition.""" + measure_11 = literal_column('("Measure 1") + ("Measure 2")').label("Measure 11") + + query = select(measure_11).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + assert '["Measure 11"]' in compiled + assert "+" in compiled + + def test_calculated_measure_complex_expression(self, pt_search_table): + """Test a complex calculated measure with nested parens and multiplication.""" + measure_8 = literal_column( + '("UserInfo_Ring Count" + "UserInfo_Region Count") * 2' + ).label("Measure 8") + + query = select(measure_8).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + assert '["Measure 8"]' in compiled + assert "* 2" in compiled + assert "+" in compiled + + def test_no_double_bracketing(self, pt_search_table): + """Test that there's no double bracketing like [["col"]].""" + measure = literal_column('"UserInfo_Ring Count"').label("Test Measure") + + query = select(measure).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should not have double brackets + assert '[["' not in compiled + assert '"]]' not in compiled + + def test_wrapped_aggregate_extracted_correctly(self, pt_search_table): + """Test that aggregates wrapped in parens (like ((COUNT(col)))) are extracted correctly.""" + measure_4 = literal_column("((COUNT(UserInfo_Ring)))").label("Measure 4") + + query = select(measure_4).select_from(pt_search_table) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have summarize with the aggregate + assert "summarize" in compiled + + # Find the extend part + extend_idx = compiled.find("extend") + if extend_idx != -1: + project_idx = compiled.find("| project") + extend_part = ( + compiled[extend_idx:project_idx] + if project_idx != -1 + else compiled[extend_idx:] + ) + + # Should NOT have COUNT() in extend + assert "COUNT(" not in extend_part + assert "count(" not in extend_part + # Should have a reference + assert '["Measure 4"]' in extend_part + + def test_floating_point_numbers(self, pt_search_table): + """Test that floating point numbers are preserved correctly.""" + measure_1 = literal_column("count()").label("Measure 1") + measure_2 = literal_column('"Measure 1" * 0.5').label("Measure 2") + measure_3 = literal_column('"Measure 1" * 1.25').label("Measure 3") + measure_4 = literal_column('"Measure 1" / 0.1').label("Measure 4") + + query = select(measure_1, measure_2, measure_3, measure_4).select_from( + pt_search_table + ) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Floating point numbers should be preserved, not wrapped in brackets + assert "* 0.5" in compiled + assert "* 1.25" in compiled + assert "/ 0.1" in compiled + # Should NOT have bracketed numbers + assert '["0.5"]' not in compiled + assert '["1.25"]' not in compiled + assert '["0.1"]' not in compiled + + def test_calculated_measure_with_two_adhoc_measures(): """Test calculated measure referencing two ad hoc measures. @@ -720,6 +841,156 @@ def test_calculated_measure_with_mixed_aggregates_and_references(): # ============================================================================ +class TestFindTopLevelOperator: + """Tests for _find_top_level_operator helper.""" + + NOT_FOUND = -1 + # Expected positions for operator in various test strings + POS_AFTER_SPACE_CHAR = 2 # "a + b" -> operator at index 2 + POS_AFTER_PARENS = 4 # "(a) + (b)" -> operator at index 4 + POS_AFTER_DOUBLE_PARENS = 6 # "((a)) + ((b))" -> operator at index 6 + POS_AFTER_ESCAPED_QUOTE = 7 # '"a\"b" + c' -> operator at index 7 + POS_MINUS = 6 # "a + b - c * d / e" -> minus at index 6 + POS_MULT = 10 # "a + b - c * d / e" -> mult at index 10 + POS_DIV = 14 # "a + b - c * d / e" -> div at index 14 + + def test_finds_operator_at_start(self): + assert KustoKqlCompiler._find_top_level_operator("+ b", "+") == 0 + + def test_finds_operator_in_middle(self): + assert ( + KustoKqlCompiler._find_top_level_operator("a + b", "+") + == self.POS_AFTER_SPACE_CHAR + ) + + def test_finds_operator_at_end(self): + assert ( + KustoKqlCompiler._find_top_level_operator("a +", "+") + == self.POS_AFTER_SPACE_CHAR + ) + + def test_not_found_returns_minus_one(self): + assert KustoKqlCompiler._find_top_level_operator("a b", "+") == self.NOT_FOUND + + def test_operator_inside_double_quotes_not_found(self): + assert ( + KustoKqlCompiler._find_top_level_operator('"a + b"', "+") == self.NOT_FOUND + ) + + def test_operator_inside_single_quotes_not_found(self): + assert ( + KustoKqlCompiler._find_top_level_operator("'a + b'", "+") == self.NOT_FOUND + ) + + def test_operator_inside_brackets_not_found(self): + assert ( + KustoKqlCompiler._find_top_level_operator('["a + b"]', "+") + == self.NOT_FOUND + ) + + def test_operator_inside_parens_not_found(self): + assert ( + KustoKqlCompiler._find_top_level_operator("(a + b)", "+") == self.NOT_FOUND + ) + + def test_finds_operator_outside_parens(self): + result = KustoKqlCompiler._find_top_level_operator("(a) + (b)", "+") + assert result == self.POS_AFTER_PARENS + + def test_finds_operator_with_nested_parens(self): + result = KustoKqlCompiler._find_top_level_operator("((a)) + ((b))", "+") + assert result == self.POS_AFTER_DOUBLE_PARENS + + def test_mixed_quotes_and_operator(self): + result = KustoKqlCompiler._find_top_level_operator("\"col\" + 'value'", "+") + assert result == self.POS_AFTER_DOUBLE_PARENS + + def test_operator_after_escaped_quote(self): + # Escaped quote should not affect detection + text = r'"a\"b" + c' + assert ( + KustoKqlCompiler._find_top_level_operator(text, "+") + == self.POS_AFTER_ESCAPED_QUOTE + ) + + def test_multiple_operators_finds_first(self): + result = KustoKqlCompiler._find_top_level_operator("a + b + c", "+") + assert result == self.POS_AFTER_SPACE_CHAR + + def test_different_operator_types(self): + text = "a + b - c * d / e" + assert ( + KustoKqlCompiler._find_top_level_operator(text, "+") + == self.POS_AFTER_SPACE_CHAR + ) + assert KustoKqlCompiler._find_top_level_operator(text, "-") == self.POS_MINUS + assert KustoKqlCompiler._find_top_level_operator(text, "*") == self.POS_MULT + assert KustoKqlCompiler._find_top_level_operator(text, "/") == self.POS_DIV + + def test_empty_string(self): + assert KustoKqlCompiler._find_top_level_operator("", "+") == self.NOT_FOUND + + +class TestCountOuterParens: + """Tests for _count_outer_parens helper.""" + + ZERO_PARENS = 0 + ONE_PAREN = 1 + TWO_PARENS = 2 + THREE_PARENS = 3 + + def test_no_parens(self): + count, inner = KustoKqlCompiler._count_outer_parens("a + b") + assert count == self.ZERO_PARENS + assert inner == "a + b" + + def test_single_outer_paren(self): + count, inner = KustoKqlCompiler._count_outer_parens("(a + b)") + assert count == self.ONE_PAREN + assert inner == "a + b" + + def test_double_outer_parens(self): + count, inner = KustoKqlCompiler._count_outer_parens("((a + b))") + assert count == self.TWO_PARENS + assert inner == "a + b" + + def test_triple_outer_parens(self): + count, inner = KustoKqlCompiler._count_outer_parens("(((x)))") + assert count == self.THREE_PARENS + assert inner == "x" + + def test_parens_not_matching(self): + # (a) + (b) - first paren doesn't wrap the whole expression + count, inner = KustoKqlCompiler._count_outer_parens("(a) + (b)") + assert count == self.ZERO_PARENS + assert inner == "(a) + (b)" + + def test_mixed_outer_and_inner(self): + count, inner = KustoKqlCompiler._count_outer_parens("((a + (b)))") + assert count == self.TWO_PARENS + assert inner == "a + (b)" + + def test_with_whitespace(self): + count, inner = KustoKqlCompiler._count_outer_parens(" ( (x) ) ") + assert count == self.TWO_PARENS + assert inner == "x" + + def test_empty_string(self): + count, inner = KustoKqlCompiler._count_outer_parens("") + assert count == self.ZERO_PARENS + assert inner == "" + + def test_single_char(self): + count, inner = KustoKqlCompiler._count_outer_parens("x") + assert count == self.ZERO_PARENS + assert inner == "x" + + def test_parens_only(self): + count, inner = KustoKqlCompiler._count_outer_parens("()") + assert count == self.ONE_PAREN + assert inner == "" + + class TestIsInsideQuotesOrBrackets: """Tests for _is_inside_quotes_or_brackets helper.""" @@ -778,42 +1049,58 @@ def test_out_of_bounds_position(self): class TestFindMatchingParen: """Tests for _find_matching_paren helper.""" + NOT_FOUND = -1 + # Expected closing paren positions for various test strings + SIMPLE_CLOSE = 7 # "count(x)" -> closing paren at 7 + OUTER_CLOSE = 12 # "sum(count(x))" -> outer closing at 12 + INNER_CLOSE = 11 # "sum(count(x))" -> inner closing at 11 + QUOTED_CLOSE = 23 # 'func("text(with)parens")' -> closing at 23 + BRACKETED_CLOSE = 15 # 'func(["col(1)"])' -> closing at 15 + EMPTY_CLOSE = 5 # "func()" -> closing at 5 + DEEPLY_NESTED_OUTER = 9 # "a(b(c(d)))" -> outermost closing at 9 + DEEPLY_NESTED_MID = 8 # "a(b(c(d)))" -> middle closing at 8 + DEEPLY_NESTED_INNER = 7 # "a(b(c(d)))" -> innermost closing at 7 + def test_simple_parentheses(self): text = "count(x)" - assert KustoKqlCompiler._find_matching_paren(text, 5) == 7 + assert KustoKqlCompiler._find_matching_paren(text, 5) == self.SIMPLE_CLOSE def test_nested_parentheses(self): text = "sum(count(x))" - assert KustoKqlCompiler._find_matching_paren(text, 3) == 12 # Outer paren - assert KustoKqlCompiler._find_matching_paren(text, 9) == 11 # Inner paren + assert KustoKqlCompiler._find_matching_paren(text, 3) == self.OUTER_CLOSE + assert KustoKqlCompiler._find_matching_paren(text, 9) == self.INNER_CLOSE def test_parentheses_with_quotes(self): # Parens inside quotes should be ignored text = 'func("text(with)parens")' - assert KustoKqlCompiler._find_matching_paren(text, 4) == 23 + assert KustoKqlCompiler._find_matching_paren(text, 4) == self.QUOTED_CLOSE def test_parentheses_with_brackets(self): # Parens inside brackets should be ignored text = 'func(["col(1)"])' - assert KustoKqlCompiler._find_matching_paren(text, 4) == 15 + assert KustoKqlCompiler._find_matching_paren(text, 4) == self.BRACKETED_CLOSE def test_no_opening_paren_at_position(self): text = "no paren here" - assert KustoKqlCompiler._find_matching_paren(text, 0) == -1 + assert KustoKqlCompiler._find_matching_paren(text, 0) == self.NOT_FOUND def test_unmatched_parenthesis(self): text = "func(x" - assert KustoKqlCompiler._find_matching_paren(text, 4) == -1 + assert KustoKqlCompiler._find_matching_paren(text, 4) == self.NOT_FOUND def test_empty_parentheses(self): text = "func()" - assert KustoKqlCompiler._find_matching_paren(text, 4) == 5 + assert KustoKqlCompiler._find_matching_paren(text, 4) == self.EMPTY_CLOSE def test_deeply_nested(self): text = "a(b(c(d)))" - assert KustoKqlCompiler._find_matching_paren(text, 1) == 9 - assert KustoKqlCompiler._find_matching_paren(text, 3) == 8 - assert KustoKqlCompiler._find_matching_paren(text, 5) == 7 + assert ( + KustoKqlCompiler._find_matching_paren(text, 1) == self.DEEPLY_NESTED_OUTER + ) + assert KustoKqlCompiler._find_matching_paren(text, 3) == self.DEEPLY_NESTED_MID + assert ( + KustoKqlCompiler._find_matching_paren(text, 5) == self.DEEPLY_NESTED_INNER + ) class TestHasOperatorsOutsideQuotes: @@ -853,13 +1140,17 @@ def test_aggregate_with_operators(self): class TestExtractAndReplaceAggregates: """Tests for _extract_and_replace_aggregates helper.""" + ZERO_AGGS = 0 + ONE_AGG = 1 + TWO_AGGS = 2 + def test_single_aggregate(self): expr = 'count(["a"])' result, new_aggs = KustoKqlCompiler._extract_and_replace_aggregates( expr, "Measure" ) assert result == '["__Measure_1"]' - assert len(new_aggs) == 1 + assert len(new_aggs) == self.ONE_AGG assert new_aggs[0] == ('["__Measure_1"]', 'count(["a"])') def test_two_aggregates_with_operator(self): @@ -868,7 +1159,7 @@ def test_two_aggregates_with_operator(self): expr, "Measure" ) assert result == '["__Measure_1"] + ["__Measure_2"]' - assert len(new_aggs) == 2 + assert len(new_aggs) == self.TWO_AGGS assert new_aggs[0][1] == 'count(["a"])' assert new_aggs[1][1] == 'sum(["b"])' @@ -878,7 +1169,7 @@ def test_no_aggregates(self): expr, "Measure" ) assert result == '["a"] + ["b"]' - assert len(new_aggs) == 0 + assert len(new_aggs) == self.ZERO_AGGS def test_reuses_existing_aggregate(self): expr = 'count(["a"]) + count(["a"])' @@ -887,7 +1178,7 @@ def test_reuses_existing_aggregate(self): ) # Both should use the same reference assert result == '["__Measure_1"] + ["__Measure_1"]' - assert len(new_aggs) == 1 # Only one unique aggregate + assert len(new_aggs) == self.ONE_AGG # Only one unique aggregate def test_existing_aggs_parameter(self): existing = {'count(["a"])': '["existing_ref"]'} @@ -896,7 +1187,7 @@ def test_existing_aggs_parameter(self): expr, "Measure", existing ) assert '["existing_ref"]' in result - assert len(new_aggs) == 1 # Only sum is new, count is reused + assert len(new_aggs) == self.ONE_AGG # Only sum is new, count is reused def test_aggregate_in_quotes_ignored(self): expr = '"count(a)"' @@ -904,7 +1195,7 @@ def test_aggregate_in_quotes_ignored(self): expr, "Measure" ) assert result == '"count(a)"' - assert len(new_aggs) == 0 + assert len(new_aggs) == self.ZERO_AGGS def test_aggregate_in_brackets_ignored(self): expr = '["count(a)"]' @@ -912,7 +1203,7 @@ def test_aggregate_in_brackets_ignored(self): expr, "Measure" ) assert result == '["count(a)"]' - assert len(new_aggs) == 0 + assert len(new_aggs) == self.ZERO_AGGS def test_measure_name_with_special_chars(self): expr = 'count(["a"])' @@ -927,7 +1218,7 @@ def test_complex_expression(self): assert '["__Pct_1"]' in result assert '["__Pct_2"]' in result assert "* 100 /" in result - assert len(new_aggs) == 2 + assert len(new_aggs) == self.TWO_AGGS class TestContainsAggregateFunction: From 98da2eb00bcdcd2e5c4f67777c0a84481a7331a4 Mon Sep 17 00:00:00 2001 From: Alison Gim Date: Wed, 11 Feb 2026 00:40:25 +0000 Subject: [PATCH 3/3] Refactor: use > 1 instead of >= 2 to avoid PLR2004 noqa comment --- sqlalchemy_kusto/dialect_kql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index 5043819..cbc82d4 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -203,7 +203,7 @@ def _count_outer_parens(text: str) -> tuple[int, str]: """Count and strip outer parentheses from text. Returns (count, stripped_text).""" text = text.strip() count = 0 - while len(text) >= 2 and text[0] == "(" and text[-1] == ")": # noqa: PLR2004 + while len(text) > 1 and text[0] == "(" and text[-1] == ")": depth = 0 for ch in text[:-1]: # Scan all but last char depth += (ch == "(") - (ch == ")") @@ -590,6 +590,7 @@ def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: or KustoKqlCompiler._is_number_literal(name) ) and not is_alias: return name + # First, check if the name is already wrapped in ["ColumnName"] (escaped format) if name.startswith('["') and name.endswith('"]'): return name # Return as is if already properly escaped # Handle arithmetic expressions by recursively processing operands