diff --git a/core/ast/__init__.py b/core/ast/__init__.py index 1474504..5f7b0dd 100644 --- a/core/ast/__init__.py +++ b/core/ast/__init__.py @@ -13,6 +13,7 @@ LiteralNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, OperatorNode, FunctionNode, SelectNode, @@ -37,6 +38,7 @@ 'LiteralNode', 'ElementVariableNode', 'SetVariableNode', + 'VariableLiteralNode', 'OperatorNode', 'FunctionNode', 'SelectNode', diff --git a/core/ast/enums.py b/core/ast/enums.py index 63f79dc..fa53cd3 100644 --- a/core/ast/enums.py +++ b/core/ast/enums.py @@ -20,6 +20,7 @@ class NodeType(Enum): # VarSQL specific VAR = "var" VARSET = "varset" + VAR_LITERAL = "var_literal" # Operators OPERATOR = "operator" diff --git a/core/ast/node.py b/core/ast/node.py index b0d72ca..0eddc76 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Set, Optional, Union +from typing import List, Set, Optional, Tuple, Union from abc import ABC from .enums import NodeType, JoinType, SortOrder @@ -46,7 +46,7 @@ def __hash__(self): class TableNode(Node): """Table reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs): + def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): super().__init__(NodeType.TABLE, **kwargs) self.name = _name self.alias = _alias @@ -80,7 +80,7 @@ def __hash__(self): class ColumnNode(Node): """Column reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): + def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): super().__init__(NodeType.COLUMN, **kwargs) self.name = _name self.alias = _alias @@ -101,18 +101,20 @@ def __hash__(self): class LiteralNode(Node): """Literal value node""" - def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs): + def __init__(self, _value: str|int|float|bool|datetime|None, _alias: Optional[str] = None, **kwargs): super().__init__(NodeType.LITERAL, **kwargs) self.value = _value + self.alias = _alias def __eq__(self, other): if not isinstance(other, LiteralNode): return False return (super().__eq__(other) and - self.value == other.value) + self.value == other.value and + self.alias == other.alias) def __hash__(self): - return hash((super().__hash__(), self.value)) + return hash((super().__hash__(), self.value, self.alias)) class DataTypeNode(Node): """SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER)""" @@ -170,17 +172,19 @@ def __hash__(self): class ElementVariableNode(Node): """Rule element variable ```` (see ``VarType.ElementVariable`` in rule_parser_v2).""" - def __init__(self, _name: str, **kwargs): + def __init__(self, _name: str, parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): super().__init__(NodeType.VAR, **kwargs) self.name = _name + self.parent_alias = parent_alias + self.alias = alias def __eq__(self, other): if not isinstance(other, ElementVariableNode): return False - return super().__eq__(other) and self.name == other.name + return super().__eq__(other) and self.name == other.name and self.parent_alias == other.parent_alias and self.alias == other.alias def __hash__(self): - return hash((super().__hash__(), self.name)) + return hash((super().__hash__(), self.name, self.parent_alias, self.alias)) class SetVariableNode(Node): @@ -198,6 +202,31 @@ def __hash__(self): return hash((super().__hash__(), self.name)) +class VariableLiteralNode(Node): + """A string literal placeholder, e.g. ``'%%'`` in a LIKE predicate. + + ``prefix`` and ``suffix`` capture surrounding wildcard characters so + ``LIKE '%foo%'`` → ``VariableLiteralNode('x1', prefix='%', suffix='%')``. + """ + def __init__(self, _name: str, prefix: str = "", suffix: str = "", + _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.VAR_LITERAL, **kwargs) + self.name = _name + self.prefix = prefix + self.suffix = suffix + self.alias = _alias + + def __eq__(self, other): + if not isinstance(other, VariableLiteralNode): + return False + return (super().__eq__(other) and self.name == other.name + and self.prefix == other.prefix and self.suffix == other.suffix + and self.alias == other.alias) + + def __hash__(self): + return hash((super().__hash__(), self.name, self.prefix, self.suffix, self.alias)) + + class OperatorNode(Node): """Operator node""" def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs): @@ -229,7 +258,7 @@ def __init__(self, _operand: Node, _name: str, **kwargs): class FunctionNode(Node): """Function call node""" - def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs): + def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): if _args is None: _args = [] super().__init__(NodeType.FUNCTION, children=_args, **kwargs) @@ -249,24 +278,31 @@ def __hash__(self): class JoinNode(Node): """JOIN clause node""" - def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs): + def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, _using: Optional[List['Node']] = None, **kwargs): children = [_left_table, _right_table] if _on_condition: children.append(_on_condition) + if _using: + children.extend(_using) super().__init__(NodeType.JOIN, children=children, **kwargs) self.left_table = _left_table self.right_table = _right_table self.join_type = _join_type self.on_condition = _on_condition - + self.using = list(_using) if _using else None + def __eq__(self, other): if not isinstance(other, JoinNode): return False return (super().__eq__(other) and - self.join_type == other.join_type) + self.join_type == other.join_type and + self.using == other.using) def __hash__(self): - return hash((super().__hash__(), self.join_type)) + using_key: Tuple = () + if self.using: + using_key = tuple(self.using) + return hash((super().__hash__(), self.join_type, using_key)) # ============================================================================ # Query Structure Nodes @@ -463,4 +499,4 @@ def __eq__(self, other): return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val def __hash__(self): - return hash((super().__hash__(), tuple(self.whens), self.else_val)) \ No newline at end of file + return hash((super().__hash__(), tuple(self.whens), self.else_val)) diff --git a/core/query_formatter.py b/core/query_formatter.py index 32acfcd..1af1bb4 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -12,11 +12,33 @@ OrderByNode, JoinNode, SubqueryNode, + ElementVariableNode, + SetVariableNode, + VariableLiteralNode, ) from core.ast.enums import NodeType, JoinType from core.ast.node import Node from core.ast.utils import flatten_logical_operands + +def _normalize_placeholder_tokens(sql: str) -> str: + out = re.sub(r"__rvs_(\w+)__", r"<<\1>>", sql) + out = re.sub(r"__rv_(\w+)__", r"<\1>", out) + return out + + +def _render_alias(alias) -> str: + """Render an alias/parent_alias field to a string for mosql output. + + Concrete string aliases pass through unchanged; ElementVariableNode aliases + emit their placeholder token (``__rv_name__``), which ``_normalize_placeholder_tokens`` + later converts to ````. + """ + if isinstance(alias, ElementVariableNode): + return f"__rv_{alias.name}__" + return alias + + class QueryFormatter: def format(self, query: Node) -> str: # [1] AST -> JSON @@ -26,8 +48,9 @@ def format(self, query: Node) -> str: sql = mosql.format(json_query) # Fixes edge case where formatting json with INTERVAL '0' SECOND into SQL adds quotes - sql = re.sub(r"INTERVAL '(\d+)'", r'INTERVAL \1', sql) - + sql = re.sub(r"INTERVAL '(\d+)'", r'INTERVAL \1', sql) + + sql = _normalize_placeholder_tokens(sql) return sql def _collect_union_branches(node: CompoundQueryNode, is_all: bool) -> list: @@ -83,9 +106,19 @@ def ast_to_json(node: Node) -> dict: elif child.type == NodeType.ORDER_BY: result['orderby'] = format_order_by(child) elif child.type == NodeType.LIMIT: - result['limit'] = child.limit + lv = child.limit + if isinstance(lv, ElementVariableNode): + lv = f"__rv_{lv.name}__" + elif isinstance(lv, SetVariableNode): + lv = f"__rvs_{lv.name}__" + result['limit'] = lv elif child.type == NodeType.OFFSET: - result['offset'] = child.offset + ov = child.offset + if isinstance(ov, ElementVariableNode): + ov = f"__rv_{ov.name}__" + elif isinstance(ov, SetVariableNode): + ov = f"__rvs_{ov.name}__" + result['offset'] = ov return result @@ -105,19 +138,10 @@ def format_select(select_node: SelectNode) -> dict: items = [] for child in children: - if child.type == NodeType.COLUMN: - if child.alias: - items.append({'name': child.alias, 'value': format_expression(child)}) - else: - items.append({'value': format_expression(child)}) - elif child.type == NodeType.FUNCTION: - func_expr = format_expression(child) - if hasattr(child, 'alias') and child.alias: - items.append({'name': child.alias, 'value': func_expr}) - else: - items.append({'value': func_expr}) - else: - items.append({'value': format_expression(child)}) + item = {'value': format_expression(child)} + if hasattr(child, 'alias') and child.alias: + item['name'] = _render_alias(child.alias) + items.append(item) select_key = 'select_distinct' if select_node.distinct else 'select' result[select_key] = items @@ -173,29 +197,20 @@ def format_from(from_node: FromNode): def format_join(join_node: JoinNode) -> list: """Format a JOIN node""" - children = list(join_node.children) - - if len(children) < 2: - raise ValueError("JoinNode must have at least 2 children (left and right tables)") - - left_node = children[0] - right_node = children[1] - join_condition = children[2] if len(children) > 2 else None - + left_node = join_node.left_table + right_node = join_node.right_table + join_condition = join_node.on_condition + using_columns = join_node.using + result = [] - - # Format left side (could be a table or nested join) + if left_node.type == NodeType.JOIN: - # Nested join - recursively format result.extend(format_join(left_node)) else: - # Simple table - this becomes the FROM table result.append(format_source(left_node)) - # Format the join itself join_dict = {} - # Map join types to mosql format join_type_map = { JoinType.JOIN: 'join', JoinType.INNER: 'inner join', @@ -203,17 +218,22 @@ def format_join(join_node: JoinNode) -> list: JoinType.RIGHT: 'right join', JoinType.FULL: 'full join', JoinType.CROSS: 'cross join', + JoinType.NATURAL: 'natural join', } join_key = join_type_map.get(join_node.join_type, 'join') join_dict[join_key] = format_source(right_node) - - # Add join condition if it exists + if join_condition: join_dict['on'] = format_expression(join_condition) - + if using_columns: + if len(using_columns) == 1: + join_dict['using'] = format_expression(using_columns[0]) + else: + join_dict['using'] = [format_expression(col) for col in using_columns] + result.append(join_dict) - + return result @@ -225,16 +245,21 @@ def format_source(node: Node) -> dict: subquery_child = list(node.children)[0] result = {'value': ast_to_json(subquery_child)} if node.alias: - result['name'] = node.alias + result['name'] = _render_alias(node.alias) + return result + elif node.type == NodeType.VAR: + result = {'value': f"__rv_{node.name}__"} + if node.alias: + result['name'] = _render_alias(node.alias) return result raise ValueError(f"Unsupported source type: {node.type}") def format_table(table_node: TableNode) -> dict: - """Format a table reference""" + """Format a table reference.""" result = {'value': table_node.name} if table_node.alias: - result['name'] = table_node.alias + result['name'] = _render_alias(table_node.alias) return result @@ -293,9 +318,21 @@ def format_order_by(order_by_node: OrderByNode) -> list: def format_expression(node: Node): """Format an expression node""" + if node.type == NodeType.VAR: + token = f"__rv_{node.name}__" + if node.parent_alias: + pa = _render_alias(node.parent_alias) + return f"{pa}.{token}" + return token + if node.type == NodeType.VARSET: + return f"__rvs_{node.name}__" + if node.type == NodeType.VAR_LITERAL: + return {'literal': f"{node.prefix}__rv_{node.name}__{node.suffix}"} + if node.type == NodeType.COLUMN: if node.parent_alias: - return f"{node.parent_alias}.{node.name}" + pa_token = _render_alias(node.parent_alias) + return f"{pa_token}.{node.name}" return node.name elif node.type == NodeType.LITERAL: @@ -408,5 +445,17 @@ def format_expression(node: Node): unit = node.unit.name.lower() return {'interval': [value, unit]} + elif node.type == NodeType.VAR: + return node.name + + elif node.type == NodeType.VARSET: + return node.name + + elif node.type == NodeType.QUERY: + return ast_to_json(node) + + elif node.type == NodeType.COMPOUND_QUERY: + return compound_to_mosql_json(node) + else: - raise ValueError(f"Unsupported node type in expression: {node.type}") \ No newline at end of file + raise ValueError(f"Unsupported node type in expression: {node.type}") diff --git a/core/query_parser.py b/core/query_parser.py index 10f6815..adf546a 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -8,6 +8,7 @@ ) # TODO: implement ElementVariableNode, SetVariableNode from core.ast.enums import JoinType, SortOrder +from typing import List, Optional import mo_sql_parsing as mosql import json @@ -133,8 +134,21 @@ def _append_source(node: Node, alias): if 'on' in item: on_condition = self.parse_expression(item['on'], aliases) + using_columns: Optional[List[Node]] = None + if 'using' in item: + using_value = item['using'] + if isinstance(using_value, list): + using_columns = [ + ColumnNode(str(c)) if not isinstance(c, dict) else self.parse_expression(c, aliases) + for c in using_value + ] + elif isinstance(using_value, dict): + using_columns = [self.parse_expression(using_value, aliases)] + else: + using_columns = [ColumnNode(str(using_value))] + join_type = self.parse_join_type(join_key) - join_node = JoinNode(left_source, right_source, join_type, on_condition) + join_node = JoinNode(left_source, right_source, join_type, on_condition, using_columns) left_source = join_node elif 'value' in item: @@ -595,7 +609,9 @@ def parse_join_type(join_key: str) -> JoinType: """Extract JoinType from mo_sql_parsing join key.""" key_lower = join_key.lower().replace(' ', '_') - if 'inner' in key_lower: + if 'natural' in key_lower: + return JoinType.NATURAL + elif 'inner' in key_lower: return JoinType.INNER elif 'left' in key_lower: return JoinType.LEFT diff --git a/core/query_rewriter_v2.py b/core/query_rewriter_v2.py index 38c8612..7a2cf63 100644 --- a/core/query_rewriter_v2.py +++ b/core/query_rewriter_v2.py @@ -23,7 +23,6 @@ import copy import logging -import re from contextlib import contextmanager from collections import deque from enum import Enum @@ -60,6 +59,7 @@ TableNode, TimeUnitNode, UnaryOperatorNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -142,8 +142,6 @@ def _bind(var_name: str, value: Any, memo: dict) -> bool: # (table vars bind to whole TableNode; ColumnNode.parent_alias binds to string) if isinstance(existing, TableNode) and isinstance(value, str): return (existing.alias or existing.name) == value - if isinstance(value, TableNode) and isinstance(existing, str): - return (value.alias or value.name) == existing if isinstance(existing, Node) and isinstance(value, Node): return existing == value return existing == value @@ -151,10 +149,6 @@ def _bind(var_name: str, value: Any, memo: dict) -> bool: return True -def _is_var_name(s: Any, mapping: dict) -> bool: - """True if s is a string that is an external variable name in the rule mapping.""" - return isinstance(s, str) and s in mapping - # ============================================================================ # Core matching @@ -167,6 +161,34 @@ def _match_node( # --- variable nodes in pattern --- if isinstance(p, ElementVariableNode): + # Qualified column variable: ElementVariableNode with parent_alias that is a variable name + if isinstance(p.parent_alias, ElementVariableNode): + if not isinstance(q, ColumnNode): + return False + if q.parent_alias is None: + return False + if not _bind(p.name, q.name, memo): + return False + return _bind(p.parent_alias.name, q.parent_alias, memo) + # Table variable with variable alias: ElementVariableNode with alias that is a variable name + # e.g. ElementVariableNode("tb1", alias=ElementVariableNode("t1")) where "tb1" -> TableNode(name), "t1" -> alias string + # Bind a stripped TableNode (no alias) to p.name so that when p.name appears as a bare + # table variable in the rewrite (e.g. inner FROM ), it materializes as TableNode(name) + # without leaking the alias. The alias is separately captured via p.alias. + if isinstance(p.alias, ElementVariableNode): + if not isinstance(q, TableNode): + return False + if not _bind(p.name, TableNode(q.name), memo): + return False + if q.alias is None: + return False + return _bind(p.alias.name, q.alias, memo) + # Default: whole-node binding. + # JoinNode is a compound structural node (not an atomic value) that should never be + # bound to a bare element variable — it would violate the type contract and cause + # spurious second-pass matches after joins have been introduced. + if isinstance(q, JoinNode): + return False return _bind(p.name, q, memo) if isinstance(p, SetVariableNode): @@ -187,24 +209,30 @@ def _match_node( # --- type must be compatible --- if not isinstance(q, type(p)) and not isinstance(p, type(q)): - # Allow OperatorNode / UnaryOperatorNode subclass relationship - if not (isinstance(q, OperatorNode) and isinstance(p, OperatorNode)): - return False + # VariableLiteralNode matches against LiteralNode; allow it before the strict type guard + if not isinstance(p, VariableLiteralNode): + # Allow OperatorNode / UnaryOperatorNode subclass relationship + if not (isinstance(q, OperatorNode) and isinstance(p, OperatorNode)): + return False # --- leaf nodes --- + if isinstance(p, VariableLiteralNode): + if not isinstance(q, LiteralNode): + return False + qv = q.value + if not isinstance(qv, str): + return False + if p.prefix and not qv.startswith(p.prefix): + return False + if p.suffix and not qv.endswith(p.suffix): + return False + inner = qv[len(p.prefix): len(qv) - len(p.suffix) if p.suffix else len(qv)] + return _bind(p.name, LiteralNode(inner), memo) + if isinstance(p, LiteralNode): if not isinstance(q, LiteralNode): return False qv, pv = q.value, p.value - # RuleParserV2 may represent placeholders inside string literals like `''` - # as LiteralNode("s") (where "s" is a declared rule variable). In that case, - # treat it as a bindable placeholder rather than a concrete string. - - # TODO: We hope to further flatten variables in the literal, e.g., - # q: {like: [name, '%joe%']} -> Func(like, [Col('name'), LiteralNode('%joe%')]) - # p: {like: [x, '%y%']} -> Func(like, [EV(x),  LitrlComb([LiteralNode('%'), EV(y), LiteralNode('%')])]) - if isinstance(pv, str) and _is_var_name(pv, mapping): - return _bind(pv, q, memo) if isinstance(qv, str) and isinstance(pv, str): return qv.lower() == pv.lower() return qv == pv @@ -215,28 +243,20 @@ def _match_node( if isinstance(p, TimeUnitNode): return isinstance(q, TimeUnitNode) and q.name.upper() == p.name.upper() - # --- TableNode: name and alias may be variable names --- + # --- TableNode: name is always concrete; alias may be a variable name --- if isinstance(p, TableNode): if not isinstance(q, TableNode): return False - if _is_var_name(p.name, mapping) and p.alias is None: - # Variable stands for the entire table reference (name + alias). - # Bind to the whole TableNode so the rewrite can reproduce it faithfully. - return _bind(p.name, q, memo) - if _is_var_name(p.name, mapping): - if not _bind(p.name, q.name, memo): - return False - else: - if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): - return False + if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): + return False if p.alias is not None: # Pattern requires an alias (even if it's a variable). Do not match # unaliased tables, otherwise the alias var would bind to None and # rewrites expecting a real alias/identifier become nonsensical. if q.alias is None: return False - if _is_var_name(p.alias, mapping): - if not _bind(p.alias, q.alias, memo): + if isinstance(p.alias, ElementVariableNode): + if not _bind(p.alias.name, q.alias, memo): return False else: qa = q.alias or "" @@ -244,23 +264,19 @@ def _match_node( return False return True - # --- ColumnNode: name and parent_alias may be variable names --- + # --- ColumnNode: name is always concrete; parent_alias may be a variable name --- if isinstance(p, ColumnNode): if not isinstance(q, ColumnNode): return False - if _is_var_name(p.name, mapping): - if not _bind(p.name, q.name, memo): - return False - else: - if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): - return False + if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): + return False if p.parent_alias is not None: # Pattern requires a qualifier (even if it's a variable). Do not match # unqualified columns, otherwise the qualifier var would bind to None. if q.parent_alias is None: return False - if _is_var_name(p.parent_alias, mapping): - if not _bind(p.parent_alias, q.parent_alias, memo): + if isinstance(p.parent_alias, ElementVariableNode): + if not _bind(p.parent_alias.name, q.parent_alias, memo): return False else: qpa = q.parent_alias or "" @@ -315,6 +331,14 @@ def _match_node( return False if q.name.upper() != p.name.upper(): return False + if p.alias is not None: + if q.alias is None: + return False + if isinstance(p.alias, ElementVariableNode): + if not _bind(p.alias.name, q.alias, memo): + return False + elif isinstance(q.alias, str) and q.alias.lower() != p.alias.lower(): + return False return _match_children_list(list(q.children), list(p.children), memo, mode, mapping) # --- ListNode --- @@ -365,16 +389,16 @@ def _match_node( if isinstance(p, LimitNode): if not isinstance(q, LimitNode): return False - if isinstance(p.limit, str) and _is_var_name(p.limit, mapping): - return _bind(p.limit, q.limit, memo) + if isinstance(p.limit, ElementVariableNode): + return _bind(p.limit.name, q.limit, memo) return q.limit == p.limit # --- OffsetNode --- if isinstance(p, OffsetNode): if not isinstance(q, OffsetNode): return False - if isinstance(p.offset, str) and _is_var_name(p.offset, mapping): - return _bind(p.offset, q.offset, memo) + if isinstance(p.offset, ElementVariableNode): + return _bind(p.offset.name, q.offset, memo) return q.offset == p.offset # --- JoinNode --- @@ -631,11 +655,15 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: """Convert an element-variable binding into a concrete AST node. Rules: - - If bound to a Node, return it (but strip `.alias` to avoid leaking output aliases - unless the rewrite explicitly carries them). + - If bound to a TableNode, return it directly (table aliases must be preserved + so that qualified column references like e1.col remain valid in the rewrite). + - If bound to any other Node, return it (but strip `.alias` to avoid leaking + output aliases unless the rewrite explicitly carries them). - If bound to scalar identifiers, materialize as ColumnNode/LiteralNode so the formatter can emit SQL. """ + if isinstance(val, TableNode): + return val if isinstance(val, Node): if hasattr(val, "alias") and getattr(val, "alias") is not None: cloned = copy.deepcopy(val) @@ -644,13 +672,39 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: return val if isinstance(val, str): return ColumnNode(val) - if isinstance(val, (int, float, bool)) or val is None: - return LiteralNode(val) - # Fallback: caller will keep the variable node unchanged. return None if isinstance(node, ElementVariableNode): val = memo.get(node.name, node) + # If variable has a parent_alias that is a variable name, reconstruct qualified column + # e.g. ElementVariableNode("a1", parent_alias=ElementVariableNode("t1")) where "a1" -> "id", "t1" -> "e1" + # produces ColumnNode("id", _parent_alias="e1") + # The node's own alias (if any) is a literal string from the rewrite template; preserve it. + pa_key = node.parent_alias.name if isinstance(node.parent_alias, ElementVariableNode) else node.parent_alias + if pa_key is not None and pa_key in memo: + pa_val = memo[pa_key] + col_name = val if isinstance(val, str) else (val.name if isinstance(val, Node) and hasattr(val, 'name') else None) + if col_name is not None: + if isinstance(pa_val, TableNode): + pa_str = pa_val.alias if pa_val.alias is not None else pa_val.name + elif isinstance(pa_val, str): + pa_str = pa_val + else: + pa_str = None + if pa_str is not None: + return ColumnNode(col_name, _alias=node.alias, _parent_alias=pa_str) + # If variable has an alias that is a variable name, reconstruct a TableNode + # e.g. ElementVariableNode("tb1", alias=ElementVariableNode("t1")) where "tb1" -> TableNode("employee", ...), "t1" -> "e1" + # produces TableNode("employee", "e1") + alias_key = node.alias.name if isinstance(node.alias, ElementVariableNode) else node.alias + if alias_key is not None and alias_key in memo: + alias_val = memo[alias_key] + # val may be a whole TableNode (from binding) or a string + table_name = val.name if isinstance(val, TableNode) else None + if table_name is not None: + alias_str = alias_val if isinstance(alias_val, str) else None + if alias_str is not None: + return TableNode(table_name, alias_str) materialized = _materialize_element_binding(val) if materialized is not None: return materialized @@ -660,33 +714,19 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: # Should not appear at this level; caller handles list expansion return node + if isinstance(node, VariableLiteralNode): + bound = memo.get(node.name) + if bound is None: + return node + val = bound.value if isinstance(bound, LiteralNode) else str(bound) + return LiteralNode(f"{node.prefix}{val}{node.suffix}") + if isinstance(node, (LiteralNode, DataTypeNode, TimeUnitNode)): - # For string literals, substitute any variable names embedded in the value - # (e.g. LiteralNode('%y%') with memo['y']=LiteralNode('iphone') to LiteralNode('%iphone%')) - if isinstance(node, LiteralNode) and isinstance(node.value, str): - new_val = node.value - for var_name, bound in memo.items(): - if not isinstance(var_name, str) or var_name.startswith("_"): - continue - if var_name not in new_val: - continue - if isinstance(bound, LiteralNode) and isinstance(bound.value, (str, int, float)): - new_val = re.sub(r"\b" + re.escape(var_name) + r"\b", str(bound.value), new_val) - elif isinstance(bound, str): - new_val = re.sub(r"\b" + re.escape(var_name) + r"\b", bound, new_val) - if new_val != node.value: - return LiteralNode(new_val) return node if isinstance(node, TableNode): - # If the name variable is bound to a whole TableNode, return it directly - if isinstance(node.name, str) and node.name in memo: - val = memo[node.name] - if isinstance(val, TableNode): - return val - new_name = _subst_str(node.name, memo) new_alias = _subst_str(node.alias, memo) if node.alias is not None else None - return TableNode(new_name, new_alias) + return TableNode(node.name, new_alias) if isinstance(node, ColumnNode): new_name = _subst_str(node.name, memo) @@ -741,16 +781,16 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if isinstance(node, LimitNode): val = node.limit - if isinstance(val, str): - val = memo.get(val, val) + if isinstance(val, ElementVariableNode): + val = memo.get(val.name, val) if isinstance(val, LiteralNode): val = val.value return LimitNode(val) if isinstance(node, OffsetNode): val = node.offset - if isinstance(val, str): - val = memo.get(val, val) + if isinstance(val, ElementVariableNode): + val = memo.get(val.name, val) if isinstance(val, LiteralNode): val = val.value return OffsetNode(val) @@ -795,7 +835,7 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: def _subst_str(s: Any, memo: dict) -> Any: - """Substitute a string field if it matches a variable name in memo. + """Substitute a string or ElementVariableNode alias/parent_alias field from memo. Extracts a canonical string from the bound value: - str to return directly @@ -804,9 +844,11 @@ def _subst_str(s: Any, memo: dict) -> Any: - FunctionNode (rare): if an element variable bound to e.g. COUNT(col) in the SELECT list is substituted into a string-only context, unwrap COUNT(col) -> ``col`` name. Normally bindings are ColumnNode/TableNode/str here. + ElementVariableNode in alias/parent_alias fields (from widened field type) looks up .name in memo. """ - if isinstance(s, str) and s in memo: - val = memo[s] + key = s.name if isinstance(s, ElementVariableNode) else s + if isinstance(key, str) and key in memo: + val = memo[key] if isinstance(val, str): return val if isinstance(val, TableNode): @@ -851,7 +893,7 @@ def _replace_in_tree(tree: Node, target_id: int, replacement: Node) -> Node: return replacement if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode, TableNode, ColumnNode, - ElementVariableNode, SetVariableNode)): + ElementVariableNode, SetVariableNode, VariableLiteralNode)): return tree if isinstance(tree, FunctionNode): @@ -974,7 +1016,7 @@ def _node_subst(tree: Any, src: Any, tgt: Any) -> Any: if isinstance(tree, TableNode): return TableNode(_subst_val(tree.name, src, tgt), _subst_val(tree.alias, src, tgt)) - if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode)): + if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode, VariableLiteralNode)): return tree if isinstance(tree, FunctionNode): diff --git a/core/rule.py b/core/rule.py new file mode 100644 index 0000000..f95b743 --- /dev/null +++ b/core/rule.py @@ -0,0 +1,52 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional +from core.ast.node import Node + + +@dataclass +class RuleV2: + pattern: str + rewrite: str + pattern_ast: Node + rewrite_ast: Node + mapping: Dict[str, str] + source_pattern_ast: Optional[Node] = None + source_rewrite_ast: Optional[Node] = None + source_pattern_sql: str = "" + source_rewrite_sql: str = "" + constraints: str = "" + actions: str = "" + actions_json: List[Any] = field(default_factory=list) + id: Optional[Any] = None + key: Optional[str] = None + children: Optional[List[RuleV2]] = field(default=None) + + @classmethod + def from_dict(cls, d: dict) -> RuleV2: + return cls( + pattern=d.get("pattern", ""), + rewrite=d.get("rewrite", ""), + pattern_ast=d["pattern_ast"], + rewrite_ast=d["rewrite_ast"], + mapping=d.get("mapping", {}), + source_pattern_ast=d.get("source_pattern_ast"), + source_rewrite_ast=d.get("source_rewrite_ast"), + source_pattern_sql=d.get("source_pattern_sql", ""), + source_rewrite_sql=d.get("source_rewrite_sql", ""), + constraints=d.get("constraints", ""), + actions=d.get("actions", ""), + actions_json=d.get("actions_json", []) or [], + id=d.get("id"), + key=d.get("key"), + children=d.get("children"), + ) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) diff --git a/core/rule_generator_v2.py b/core/rule_generator_v2.py new file mode 100644 index 0000000..8d903cf --- /dev/null +++ b/core/rule_generator_v2.py @@ -0,0 +1,2118 @@ +"""AST-based rule generation helpers. + +Rule dict produced by this module: + { + "pattern": str, + "rewrite": str, + "pattern_ast": Node, + "rewrite_ast": Node, + "source_pattern_ast": Node, + "source_rewrite_ast": Node, + "source_pattern_sql": str, + "source_rewrite_sql": str, + "mapping": dict, # external variable name -> internal parser token + "constraints": str, + "actions": str, + } + +The generator starts from a concrete pattern and rewrite pair, then derives +more general rules by replacing matching tables, columns, literals, subtrees, +variable lists, and droppable branches with rule variables. Public methods keep +the rule dict shape stable while the private helpers do AST-specific traversal, +replacement, and formatting cleanup. +""" + +from __future__ import annotations + +import copy +import functools +import numbers +import re +from collections import defaultdict, deque +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union + +from core.ast.enums import NodeType +from core.ast.node import ( + CaseNode, + ColumnNode, + CompoundQueryNode, + ElementVariableNode, + FromNode, + FunctionNode, + GroupByNode, + HavingNode, + JoinNode, + LimitNode, + ListNode, + LiteralNode, + Node, + OffsetNode, + OrderByItemNode, + OrderByNode, + OperatorNode, + QueryNode, + SelectNode, + SetVariableNode, + SubqueryNode, + TableNode, + UnaryOperatorNode, + VariableLiteralNode, + WhenThenNode, + WhereNode, +) +from core.query_parser import QueryParser +from core.query_formatter import QueryFormatter +from core.rule import RuleV2 +from core.rule_parser_v2 import RuleParserV2, Scope, VarType, VarTypesInfo + + +@functools.lru_cache(maxsize=None) +def _lev_distance_cached(a: str, b: str) -> int: + if not b: + return len(a) + if not a: + return len(b) + if a[0] == b[0]: + return _lev_distance_cached(a[1:], b[1:]) + return 1 + min( + _lev_distance_cached(a[1:], b), + _lev_distance_cached(a, b[1:]), + _lev_distance_cached(a[1:], b[1:]), + ) + + +class RuleGeneratorV2: + """Generate AST-backed rewrite rules from example SQL pairs.""" + + _MAX_RECOMMENDATION_CANDIDATES = 256 # BFS cap to bound graph exploration cost + + @staticmethod + def parse_validate_single(query: str) -> Tuple[bool, str, int]: + """Validate a standalone rule query (used when only one half of a rule is being edited). + + Returns (ok, message, error_index) where error_index is the character offset of the first parse error, or 0 on success. + """ + return RuleGeneratorV2._parse_validate_impl(query, None) + + @staticmethod + def parse_validate(pattern: str, rewrite: str) -> Tuple[bool, str, int]: + """Validate a (pattern, rewrite) rule pair and return (ok, message, error_index). + + Reports bracket mismatches, parser errors on either side, and rejects rules whose rewrite uses a variable that never appears in the pattern. + """ + return RuleGeneratorV2._parse_validate_impl(pattern, rewrite) + + @staticmethod + def _parse_validate_impl(pattern: str, rewrite: Optional[str]) -> Tuple[bool, str, int]: + scope_names = { + Scope.SELECT: "SELECT", + Scope.FROM: "FROM", + Scope.WHERE: "WHERE", + Scope.CONDITION: "CONDITION", + } + scope_prefix_lengths = { + Scope.SELECT: 0, + Scope.FROM: 9, + Scope.WHERE: 16, + Scope.CONDITION: 22, + } + + wrong_bracket_pattern = RuleParserV2.find_malformed_brackets(pattern) + if wrong_bracket_pattern > -1: + return False, "mismatching brackets in query 1", wrong_bracket_pattern + if rewrite is not None: + wrong_bracket_rewrite = RuleParserV2.find_malformed_brackets(rewrite) + if wrong_bracket_rewrite > -1: + return False, "mismatching brackets in query 2", wrong_bracket_rewrite + + pattern_compact = " ".join(pattern.splitlines()) + rewrite_compact = " ".join(rewrite.splitlines()) if rewrite is not None else None + + def _first_token(sql: str) -> str: + parts = [part for part in sql.split(" ") if part] + return parts[0] if parts else "" + + for keyword in ("SELECT", "FROM", "WHERE"): + token = _first_token(pattern_compact) + if token and RuleGeneratorV2._lev_distance(keyword, token) == 1: + return False, f"possible spelling error at query 1: {token} instead of {keyword}", 0 + if rewrite_compact is not None: + token = _first_token(rewrite_compact) + if token and RuleGeneratorV2._lev_distance(keyword, token) == 1: + return False, f"possible spelling error at query 2: {token} instead of {keyword}", 0 + + try: + pattern_sql, rewrite_sql, mapping = RuleParserV2.replaceVars(pattern_compact, rewrite_compact or pattern_compact) + pattern_full, pattern_scope = RuleParserV2.extendToFullSQL(pattern_sql) + QueryParser().parse(pattern_full) + except Exception as e: + message = str(e) + display_message = RuleGeneratorV2.dereplaceVars(message, mapping) + match = re.search(r'[Ee]xpecting(.*)found "(.*)" \(at char (\d+)', display_message) + if match: + error_index = RuleGeneratorV2._rule_fragment_error_index( + int(match.group(3)), + pattern_scope, + pattern_full, + mapping, + scope_prefix_lengths, + ) + return ( + False, + "Error in first query, current Scope is " + + scope_names[pattern_scope] + + " if that is not intended check spelling at index 0. Expecting " + + match.group(1).strip() + + " found " + + match.group(2).strip(), + error_index, + ) + return False, message, -1 + + if rewrite is None: + return True, "success", 0 + + # Variables that appear only in rewrite can never be instantiated from pattern. + pattern_vars = set(re.findall(r"<<\w+>>|<\w+>", pattern)) + for match in re.finditer(r"<<\w+>>|<\w+>", rewrite): + if match.group(0) not in pattern_vars: + return False, f"{match.group(0)} not in first rule", match.start() + + try: + rewrite_full, rewrite_scope = RuleParserV2.extendToFullSQL(rewrite_sql) + QueryParser().parse(rewrite_full) + return True, "Success", 0 + except Exception as e: + message = str(e) + display_message = RuleGeneratorV2.dereplaceVars(message, mapping) + match = re.search(r'[Ee]xpecting(.*)found "(.*)" \(at char (\d+)', display_message) + if match: + error_index = RuleGeneratorV2._rule_fragment_error_index( + int(match.group(3)), + rewrite_scope, + rewrite_full, + mapping, + scope_prefix_lengths, + ) + return ( + False, + "Error in second query, current Scope is " + + scope_names[rewrite_scope] + + " if that is not intended check spelling at index 0. Expecting " + + match.group(1).strip() + + " found " + + match.group(2).strip(), + error_index, + ) + return False, message, -1 + + @staticmethod + def _lev_distance(a: str, b: str) -> int: + return _lev_distance_cached(a, b) + + @staticmethod + def _rule_fragment_error_index( + parser_char_index: int, + scope: Scope, + full_sql: str, + mapping: Dict[str, str], + scope_prefix_lengths: Dict[Scope, int], + ) -> int: + """Translate a parser error offset from wrapped SQL back to the rule fragment. + + Validation parses fragments after wrapping them into complete SQL and + replacing user placeholders with parser-safe internal variable tokens. + The returned index points at the user's original fragment. + """ + error_index = parser_char_index - scope_prefix_lengths[scope] + prefix = full_sql[:parser_char_index] + for internal_name in mapping.values(): + diff = RuleGeneratorV2._internal_variable_token_length_delta(internal_name) + if diff <= 0: + continue + error_index -= prefix.count(internal_name) * diff + return error_index + + @staticmethod + def _internal_variable_token_length_delta(internal_name: str) -> int: + if internal_name.startswith(VarTypesInfo[VarType.ElementVariable]["internalBase"]): + display_token = "V" + internal_name[len(VarTypesInfo[VarType.ElementVariable]["internalBase"]):] + return len(internal_name) - len(display_token) + if internal_name.startswith(VarTypesInfo[VarType.SetVariable]["internalBase"]): + display_token = "VL" + internal_name[len(VarTypesInfo[VarType.SetVariable]["internalBase"]):] + return len(internal_name) - len(display_token) + return 0 + + @staticmethod + def initialize_seed_rule(q0: str, q1: str) -> RuleV2: + """Build the initial (un-generalized) rule for the rewrite pair q0 -> q1. + + Parses both sides via RuleParserV2, snapshots the source ASTs/SQL, and returns a fresh RuleV2 carrying pattern, rewrite, pattern_ast, rewrite_ast, mapping, and empty constraints/actions. + """ + parsed = RuleParserV2.parse(q0, q1) + pattern = RuleGeneratorV2.deparse(copy.deepcopy(parsed.pattern_ast)) + rewrite = RuleGeneratorV2.deparse(copy.deepcopy(parsed.rewrite_ast)) + return RuleV2( + pattern=pattern, + rewrite=rewrite, + pattern_ast=parsed.pattern_ast, + rewrite_ast=parsed.rewrite_ast, + source_pattern_ast=copy.deepcopy(parsed.pattern_ast), + source_rewrite_ast=copy.deepcopy(parsed.rewrite_ast), + source_pattern_sql=q0, + source_rewrite_sql=q1, + mapping=parsed.mapping, + constraints="", + actions="", + ) + + RuleGeneralizations = ( + "generalize_tables", + "generalize_columns", + "generalize_literals", + "generalize_subtrees", + "generalize_variables", + "generalize_branches", + ) + + @staticmethod + def generate_general_rule(q0: str, q1: str) -> RuleV2: + """Repeatedly apply every generalize_* step until the rule's fingerprint stops changing. + + Returns the most general rule reachable from the seed by exhaustively variablizing tables/columns/literals/subtrees, merging variable lists, and dropping branches. + """ + seed_rule = RuleGeneratorV2.initialize_seed_rule(q0, q1) + general_rule = seed_rule + visited_fingerprints: Set[str] = set() + rule_fingerprint = RuleGeneratorV2.fingerPrint(general_rule) + while rule_fingerprint not in visited_fingerprints: + visited_fingerprints.add(rule_fingerprint) + for generalization in RuleGeneratorV2.RuleGeneralizations: + general_rule = getattr(RuleGeneratorV2, generalization)(general_rule) + rule_fingerprint = RuleGeneratorV2.fingerPrint(general_rule) + return general_rule + + @staticmethod + def generate_rule_graph(q0: str, q1: str) -> RuleV2: + """Build the full BFS graph of generalizations rooted at the seed rule for q0 -> q1. + + Each node's children list is populated with the rules reachable in one variabilization/merge/drop step; nodes with the same fingerprint are deduplicated, so the graph is a DAG, not a tree. + """ + seed_rule = RuleGeneratorV2.initialize_seed_rule(q0, q1) + seed_fp = RuleGeneratorV2.fingerPrint(seed_rule) + visited = {seed_fp: seed_rule} + queue: deque[RuleV2] = deque([seed_rule]) + while queue: + base_rule = queue.popleft() + base_rule["children"] = [] + for transform in ( + RuleGeneratorV2.variablize_tables, + RuleGeneratorV2.variablize_columns, + RuleGeneratorV2.variablize_literals, + RuleGeneratorV2.variablize_subtrees, + RuleGeneratorV2.merge_variables, + RuleGeneratorV2.drop_branches, + ): + for child_rule in transform(base_rule): + child_fp = RuleGeneratorV2.fingerPrint(child_rule) + if child_fp not in visited: + visited[child_fp] = child_rule + queue.append(child_rule) + base_rule["children"].append(child_rule) + else: + base_rule["children"].append(visited[child_fp]) + return seed_rule + + @staticmethod + def recommend_simple_rules(examples: List[Dict[str, str]]) -> List[RuleV2]: + """Pick a small set of generalized rules that together cover every (q0, q1) example. + + Generates candidate rules per example, fingerprints them, and greedy set-covers the still-uncovered examples, breaking ties toward fewer variables. + """ + fingerprint_to_examples: Dict[str, Set[int]] = defaultdict(set) + fingerprint_to_rule: Dict[str, RuleV2] = {} + example_candidates: List[List[Tuple[str, RuleV2]]] = [] + + for index, example in enumerate(examples): + seed = RuleGeneratorV2.initialize_seed_rule(example["q0"], example["q1"]) + candidates_with_fingerprints: List[Tuple[str, RuleV2]] = [] + for rule in RuleGeneratorV2._recommendation_candidates(seed): + fp = RuleGeneratorV2.fingerPrint(rule) + candidates_with_fingerprints.append((fp, rule)) + fingerprint_to_examples[fp].add(index) + current = fingerprint_to_rule.get(fp) + if current is None or RuleGeneratorV2.numberOfVariables(rule) < RuleGeneratorV2.numberOfVariables(current): + fingerprint_to_rule[fp] = rule + example_candidates.append(candidates_with_fingerprints) + + uncovered = set(range(len(examples))) + ans: List[RuleV2] = [] + for index, _example in enumerate(examples): + if index not in uncovered: + continue + chosen: Optional[RuleV2] = None + remaining = set(uncovered) + for fp, rule in example_candidates[index]: + covered = fingerprint_to_examples.get(fp, set()).intersection(remaining) + if not covered: + continue + remaining -= covered + chosen = fingerprint_to_rule.get(fp, rule) + if not remaining: + break + if chosen is not None: + uncovered = remaining + ans.append(chosen) + return ans + + @staticmethod + def _recommendation_signature(rule: RuleV2) -> str: + pattern_ast = rule.get("pattern_ast") + rewrite_ast = rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + state = { + "tables": {}, + "aliases": {}, + } + pattern_sig = RuleGeneratorV2._recommendation_ast_signature(pattern_ast, state) + rewrite_sig = RuleGeneratorV2._recommendation_ast_signature(rewrite_ast, state) + return repr((pattern_sig, rewrite_sig)) + + @staticmethod + def _recommendation_ast_signature(node: Optional[Node], state: Dict[str, Dict[str, str]]) -> object: + if node is None: + return None + + def _table_token(name: Optional[str]) -> Optional[str]: + if name is None: + return None + mapped = state["tables"].get(name) + if mapped is None: + mapped = f"T{len(state['tables']) + 1}" + state["tables"][name] = mapped + return mapped + + def _alias_token(name) -> Optional[str]: + if name is None: + return None + key = name.name if isinstance(name, ElementVariableNode) else name + mapped = state["aliases"].get(key) + if mapped is None: + mapped = f"A{len(state['aliases']) + 1}" + state["aliases"][key] = mapped + return mapped + + if isinstance(node, QueryNode): + return ("QUERY", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, SelectNode): + distinct_on = ( + RuleGeneratorV2._recommendation_ast_signature(node.distinct_on, state) + if node.distinct_on is not None + else None + ) + items = [RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children] + return ("SELECT", node.distinct, distinct_on, tuple(items)) + if isinstance(node, FromNode): + return ("FROM", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, WhereNode): + return ("WHERE", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, GroupByNode): + return ("GROUPBY", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, HavingNode): + return ("HAVING", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, OrderByNode): + return ("ORDERBY", tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children)) + if isinstance(node, OrderByItemNode): + inner = list(node.children)[0] if node.children else None + return ("ORDERBY_ITEM", node.sort.value if node.sort else None, RuleGeneratorV2._recommendation_ast_signature(inner, state)) + if isinstance(node, LimitNode): + value = node.limit + if isinstance(value, ElementVariableNode): + value = f"VAR:{RuleGeneratorV2._fingerPrint(value.name)}" + return ("LIMIT", value) + if isinstance(node, OffsetNode): + value = node.offset + if isinstance(value, ElementVariableNode): + value = f"VAR:{RuleGeneratorV2._fingerPrint(value.name)}" + return ("OFFSET", value) + if isinstance(node, TableNode): + return ("TABLE", _table_token(node.name), _alias_token(node.alias)) + if isinstance(node, SubqueryNode): + inner = list(node.children)[0] if node.children else None + return ("SUBQUERY", _alias_token(node.alias), RuleGeneratorV2._recommendation_ast_signature(inner, state)) + if isinstance(node, ColumnNode): + # Variablized columns are now ElementVariableNode — ColumnNode always has a concrete name. + return ("COLUMN", node.name, _alias_token(node.alias), _alias_token(node.parent_alias)) + if isinstance(node, LiteralNode): + return ("LITERAL", node.value, _alias_token(getattr(node, "alias", None))) + if isinstance(node, VariableLiteralNode): + return ("VAR_LITERAL", node.prefix, node.suffix) + if isinstance(node, FunctionNode): + return ( + "FUNCTION", + node.name, + _alias_token(node.alias), + tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children), + ) + if isinstance(node, JoinNode): + children = list(node.children) + return ( + "JOIN", + node.join_type.value, + tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in children), + ) + if isinstance(node, UnaryOperatorNode): + child = list(node.children)[0] if node.children else None + return ("UNARY", node.name, RuleGeneratorV2._recommendation_ast_signature(child, state)) + if isinstance(node, OperatorNode): + return ( + "OP", + node.name, + tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children), + ) + if isinstance(node, ElementVariableNode): + return ("EVAR", f"VAR:{RuleGeneratorV2._fingerPrint(node.name)}", _alias_token(node.parent_alias)) + if isinstance(node, SetVariableNode): + return ("SVAR", RuleGeneratorV2._fingerPrint(node.name)) + if isinstance(node, CompoundQueryNode): + return ( + "COMPOUND", + node.is_all, + RuleGeneratorV2._recommendation_ast_signature(node.left, state), + RuleGeneratorV2._recommendation_ast_signature(node.right, state), + ) + return ( + type(node).__name__, + tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in getattr(node, "children", [])), + ) + + @staticmethod + def _recommendation_candidates(seed: RuleV2) -> List[RuleV2]: + candidates: List[RuleV2] = [] + seed_sig = RuleGeneratorV2._recommendation_signature(seed) + seen: Set[str] = {seed_sig} + queue: deque[RuleV2] = deque([seed]) + max_candidates = RuleGeneratorV2._MAX_RECOMMENDATION_CANDIDATES + + while queue and len(candidates) < max_candidates: + base_rule = queue.popleft() + for transform in ( + RuleGeneratorV2.variablize_tables, + RuleGeneratorV2.variablize_columns, + RuleGeneratorV2.variablize_literals, + RuleGeneratorV2.variablize_subtrees, + RuleGeneratorV2.merge_variables, + RuleGeneratorV2.drop_branches, + ): + for child in transform(base_rule): + sig = RuleGeneratorV2._recommendation_signature(child) + if sig in seen: + continue + seen.add(sig) + candidates.append(child) + queue.append(child) + if len(candidates) >= max_candidates: + break + if len(candidates) >= max_candidates: + break + return candidates + + @staticmethod + def variablize_tables(rule: RuleV2) -> List[RuleV2]: + """Return one child rule per table that can still be replaced with a fresh element variable. + + Each child is the result of substituting a single table reference with on both pattern and rewrite sides. + """ + pattern_ast = rule.pattern_ast + rewrite_ast = rule.rewrite_ast + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + return [RuleGeneratorV2.variablize_table(rule, table) for table in RuleGeneratorV2.tables(pattern_ast, rewrite_ast)] + + @staticmethod + def _sync_rule_strings(rule: RuleV2) -> None: + rule.pattern = RuleGeneratorV2.deparse(rule.pattern_ast) + rule.rewrite = RuleGeneratorV2.deparse(rule.rewrite_ast) + + @staticmethod + def variablize_table(rule: Union[RuleV2, dict], table: Dict[str, str]) -> RuleV2: + """Return a new rule where the named table (and its qualified column refs) is replaced by a fresh element variable. + + table is a {"value": , "name": } descriptor as produced by tables. Both ASTs are rewritten and re-deparsed; the input rule is not mutated. + """ + if isinstance(rule, dict): + rule = RuleV2.from_dict(rule) + new_rule = copy.deepcopy(rule) + mapping = copy.deepcopy(new_rule.mapping) + if not isinstance(mapping, dict): + raise TypeError("rule.mapping must be a dict[str, str]") + + target_value = table.get("value") + target_name = table.get("name") + if not isinstance(target_value, str) or not isinstance(target_name, str): + raise TypeError("table must have string keys 'value' and 'name'") + + mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) + new_rule.mapping = mapping + + for attr in ("pattern_ast", "rewrite_ast"): + ast = getattr(new_rule, attr) + if not isinstance(ast, Node): + raise TypeError(f"rule.{attr} must be an AST Node") + setattr(new_rule, attr, RuleGeneratorV2._replace_table_in_ast( + ast, + target_value=target_value, + target_name=target_name, + placeholder_token=external_name, + )) + + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def variablize_columns(rule: RuleV2) -> List[RuleV2]: + """Return one child rule per column that can still be replaced with a fresh element variable. + + Each child substitutes one un-variablized column name with on both sides. + """ + pattern_ast = rule.get("pattern_ast") + rewrite_ast = rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + return [RuleGeneratorV2.variablize_column(rule, column) for column in RuleGeneratorV2.columns(pattern_ast, rewrite_ast)] + + @staticmethod + def variablize_column(rule: RuleV2, column: str) -> RuleV2: + """Return a new rule where every occurrence of column (in both ASTs) is replaced by a fresh element variable. + + Allocates the next available and re-deparses both sides. The input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + mapping = copy.deepcopy(new_rule["mapping"]) + if not isinstance(mapping, dict): + raise TypeError("rule['mapping'] must be a dict[str, str]") + + mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) + new_rule["mapping"] = mapping + + for key in ("pattern_ast", "rewrite_ast"): + ast = new_rule.get(key) + if not isinstance(ast, Node): + raise TypeError(f"rule['{key}'] must be an AST Node") + new_rule[key] = RuleGeneratorV2._replace_column_in_ast(ast, column, external_name) + + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def variablize_literals(rule: RuleV2) -> List[Dict[str, object]]: + """Return one child rule per literal that can still be replaced with a fresh element variable. + + Considers literals that recur within one side or are shared across both sides. + """ + pattern_ast = rule.get("pattern_ast") + rewrite_ast = rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + return [RuleGeneratorV2.variablize_literal(rule, literal) for literal in RuleGeneratorV2.literals(pattern_ast, rewrite_ast)] + + @staticmethod + def variablize_literal(rule: RuleV2, literal: Union[str, numbers.Number]) -> RuleV2: + """Return a new rule where every occurrence of literal (in both ASTs) is replaced by a fresh element variable. + + Allocates the next available and re-deparses both sides. The input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + mapping = copy.deepcopy(new_rule["mapping"]) + if not isinstance(mapping, dict): + raise TypeError("rule['mapping'] must be a dict[str, str]") + + mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) + new_rule["mapping"] = mapping + + for key in ("pattern_ast", "rewrite_ast"): + ast = new_rule.get(key) + if not isinstance(ast, Node): + raise TypeError(f"rule['{key}'] must be an AST Node") + new_rule[key] = RuleGeneratorV2._replace_literal_in_ast(ast, literal, external_name) + + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def variablize_subtrees(rule: RuleV2) -> List[Dict[str, object]]: + """Return one child rule per subtree shared by pattern and rewrite that can be collapsed into an element variable. + """ + return [RuleGeneratorV2.variablize_subtree(rule, subtree) for subtree in RuleGeneratorV2.subtrees(rule.pattern_ast, rule.rewrite_ast)] + + @staticmethod + def variablize_subtree(rule: RuleV2, subtree: Node) -> RuleV2: + """Return a new rule where every occurrence of subtree (in both ASTs) is replaced by a fresh variable. + + A list subtree (the operands of an ``IN (...)`` clause) collapses into a set + variable ``<>`` so it matches a comma-separated list of arbitrary length, + matching the `` IN (<>)`` convention. Every other subtree collapses into + an element variable ````. The input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + mapping = copy.deepcopy(new_rule["mapping"]) + if not isinstance(mapping, dict): + raise TypeError("rule['mapping'] must be a dict[str, str]") + + if isinstance(subtree, ListNode): + mapping, external_name = RuleGeneratorV2._find_next_set_variable(mapping) + replacement: Node = SetVariableNode(external_name) + else: + mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) + replacement = ElementVariableNode(external_name) + new_rule["mapping"] = mapping + + for key in ("pattern_ast", "rewrite_ast"): + ast = new_rule.get(key) + if not isinstance(ast, Node): + raise TypeError(f"rule['{key}'] must be an AST Node") + new_rule[key] = RuleGeneratorV2._replace_subtree_in_ast(ast, subtree, copy.deepcopy(replacement)) + + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def merge_variables(rule: RuleV2) -> List[Dict[str, object]]: + """Return one child rule per element-variable list collapsible into a single set variable <>. + + Each candidate list is the intersection of an AND-chain or SELECT-list on both sides. + """ + pattern_ast = rule.get("pattern_ast") + rewrite_ast = rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + return [RuleGeneratorV2.merge_variable_list(rule, variable_list) for variable_list in RuleGeneratorV2.variable_lists(pattern_ast, rewrite_ast)] + + @staticmethod + def merge_variable_list(rule: RuleV2, variable_list: List[str]) -> RuleV2: + """Return a new rule where the given element variables are collapsed into a single set variable <>. + + Allocates the next available set variable and rewrites both ASTs (and their deparsed forms) so consecutive members of variable_list share that one set variable. The input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + mapping = copy.deepcopy(new_rule["mapping"]) + if not isinstance(mapping, dict): + raise TypeError("rule['mapping'] must be a dict[str, str]") + + mapping, set_name = RuleGeneratorV2._find_next_set_variable(mapping) + new_rule["mapping"] = mapping + + var_set = set(variable_list) + for key in ("pattern_ast", "rewrite_ast"): + ast = new_rule.get(key) + if not isinstance(ast, Node): + raise TypeError(f"rule['{key}'] must be an AST Node") + new_rule[key] = RuleGeneratorV2._merge_variable_list_in_ast(ast, var_set, set_name) + + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def drop_branches(rule: RuleV2) -> List[Dict[str, object]]: + """Return one child rule per droppable branch (a clause or AND/OR conjunct that is fully variablized on both sides). + + Each child removes one branch from both pattern and rewrite, producing a strictly more general rule. + """ + pattern_ast = rule.get("pattern_ast") + rewrite_ast = rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + return [RuleGeneratorV2.drop_branch(rule, branch) for branch in RuleGeneratorV2.branches(pattern_ast, rewrite_ast)] + + @staticmethod + def drop_branch(rule: RuleV2, branch: Dict[str, object]) -> RuleV2: + """Return a new rule with branch removed from both pattern and rewrite ASTs. + + branch is a descriptor produced by branches (e.g. {"key": "where", "value": ...}). The input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + for key in ("pattern_ast", "rewrite_ast"): + ast = new_rule.get(key) + if not isinstance(ast, Node): + raise TypeError(f"rule['{key}'] must be an AST Node") + new_rule[key] = RuleGeneratorV2._drop_branch_in_ast(ast, branch) + RuleGeneratorV2._sync_rule_strings(new_rule) + return new_rule + + @staticmethod + def generalize_tables(rule: RuleV2) -> RuleV2: + """Return a new rule with every replaceable table variabilized in one pass. + + Walks the candidate tables and applies variablize_table repeatedly. Returns a fresh dict; the input rule is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for table in RuleGeneratorV2.tables(pattern_ast, rewrite_ast): + new_rule = RuleGeneratorV2.variablize_table(new_rule, table) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + @staticmethod + def generalize_columns(rule: RuleV2) -> RuleV2: + """Return a new rule with every replaceable column variabilized in one pass. + + Returns a fresh dict; the input is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for column in RuleGeneratorV2.columns(pattern_ast, rewrite_ast): + new_rule = RuleGeneratorV2.variablize_column(new_rule, column) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + @staticmethod + def generalize_literals(rule: RuleV2) -> RuleV2: + """Return a new rule with every replaceable literal variabilized in one pass. + + Returns a fresh dict; the input is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for literal in RuleGeneratorV2.literals(pattern_ast, rewrite_ast): + new_rule = RuleGeneratorV2.variablize_literal(new_rule, literal) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + @staticmethod + def generalize_subtrees(rule: RuleV2) -> RuleV2: + """Return a new rule with every shared, fully-variablized subtree collapsed into a single element variable. + + Returns a fresh dict; the input is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for subtree in RuleGeneratorV2.subtrees(pattern_ast, rewrite_ast): + new_rule = RuleGeneratorV2.variablize_subtree(new_rule, subtree) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + @staticmethod + def generalize_variables(rule: RuleV2) -> RuleV2: + """Return a new rule with every mergeable element-variable list collapsed into a set variable. + + Returns a fresh dict; the input is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for variable_list in RuleGeneratorV2.variable_lists(pattern_ast, rewrite_ast): + if variable_list: + new_rule = RuleGeneratorV2.merge_variable_list(new_rule, variable_list) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + @staticmethod + def generalize_branches(rule: RuleV2) -> RuleV2: + """Return a new rule with every droppable branch removed in one pass. + + Returns a fresh dict; the input is not mutated. + """ + new_rule = copy.deepcopy(rule) + pattern_ast = new_rule.get("pattern_ast") + rewrite_ast = new_rule.get("rewrite_ast") + if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): + raise TypeError("rule ASTs must be Node instances") + for branch in RuleGeneratorV2.branches(pattern_ast, rewrite_ast): + new_rule = RuleGeneratorV2.drop_branch(new_rule, branch) + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast + return new_rule + + + @staticmethod + def columns(pattern_ast: Node, rewrite_ast: Node) -> List[str]: + """Return the deterministic, sorted set of un-variablized column names in pattern_ast. + + Variable columns are represented as ElementVariableNode, so isinstance(node, ColumnNode) + naturally excludes them — only concrete column names are returned. + rewrite_ast is accepted but ignored. + """ + del rewrite_ast # accepted for API compatibility + found: Set[str] = set() + for node in RuleGeneratorV2._walk(pattern_ast): + if isinstance(node, ColumnNode) and node.name: + found.add(node.name) + # Sort deterministically so generalize_columns is hash-seed independent. + return sorted(found) + + @staticmethod + def literals(pattern_ast: Node, rewrite_ast: Node) -> List[Union[str, numbers.Number]]: + """Return literals worth variabilizing across the pattern and rewrite ASTs. + + Includes any literal that recurs more than once on either side, plus any literal that appears on both sides. + """ + pattern_literals = RuleGeneratorV2._literal_counts(pattern_ast) + rewrite_literals = RuleGeneratorV2._literal_counts(rewrite_ast) + + variablize_literals: List[Union[str, numbers.Number]] = [ + lit for lit, count in pattern_literals.items() if count > 1 + ] + [lit for lit, count in rewrite_literals.items() if count > 1] + + intersect_literals = set(pattern_literals.keys()).intersection(set(rewrite_literals.keys())) + return list(set(variablize_literals).union(intersect_literals)) + + @staticmethod + def _literal_counts(ast: Node) -> Dict[Union[str, numbers.Number], int]: + """Count how often each literal value appears in ast, ignoring placeholder-named string literals. + + String literals are normalized by stripping % so that 'foo%' and 'foo' collapse together. + """ + counts: Dict[Union[str, numbers.Number], int] = {} + for node in RuleGeneratorV2._walk(ast): + if node.type != NodeType.LITERAL: + continue + value = getattr(node, "value", None) + if isinstance(value, str): + counts[value.replace("%", "")] = counts.get(value.replace("%", ""), 0) + 1 + elif isinstance(value, numbers.Number): + counts[value] = counts.get(value, 0) + 1 + return counts + + @staticmethod + def tables(pattern_ast: Node, rewrite_ast: Node) -> List[Dict[str, str]]: + """Return deduplicated table references ({"value", "name"} dicts) from the pattern, augmented with any additional rewrite-side aliases for the same table names. + + Tables that appear only in the rewrite are excluded — they cannot be variablized since they have no pattern-side binding. + """ + pattern_tables = RuleGeneratorV2._tables_of_ast(pattern_ast) + rewrite_tables = RuleGeneratorV2._tables_of_ast(rewrite_ast) + + pattern_set: Dict[str, List[str]] = defaultdict(list) + rewrite_set: Dict[str, List[str]] = defaultdict(list) + + for table in pattern_tables: + value = table["value"] + alias = table["name"] + if alias not in pattern_set[value]: + pattern_set[value].append(alias) + + for table in rewrite_tables: + value = table["value"] + alias = table["name"] + if alias not in rewrite_set[value]: + rewrite_set[value].append(alias) + + superset: List[Dict[str, str]] = [] + for value, pattern_aliases in pattern_set.items(): + rewrite_aliases = rewrite_set.get(value, []) + merged_aliases = pattern_aliases + [a for a in rewrite_aliases if a not in pattern_aliases] + for alias in merged_aliases: + superset.append({"value": value, "name": alias}) + + deduped: List[Dict[str, str]] = [] + seen = set() + for table in superset: + fingerprint = f"{table['value']}-{table['name']}" + if fingerprint not in seen: + deduped.append(table) + seen.add(fingerprint) + return deduped + + @staticmethod + def _tables_of_ast(ast: Node) -> List[Dict[str, str]]: + """Return {"value", "name"} descriptors for every concrete TableNode in ast. + + Variable tables are represented as ElementVariableNode, so isinstance(node, TableNode) + naturally excludes them — only concrete table references are returned. + name is the alias when present, otherwise the table value. + """ + found: List[Dict[str, str]] = [] + for node in RuleGeneratorV2._walk(ast): + if not isinstance(node, TableNode): + continue + if not isinstance(node.name, str): + continue + alias = node.alias if isinstance(node.alias, str) else node.name + found.append({"value": node.name, "name": alias}) + return found + + @staticmethod + def variable_lists(pattern_ast: Node, rewrite_ast: Node) -> List[List[str]]: + """Return element-variable name lists that appear in both pattern and rewrite (intersected pairwise). + + Each returned list is the intersection of one pattern-side AND/SELECT chain with the first matching rewrite-side chain, suitable for collapsing into a set variable. + """ + pattern_lists = [set(v) for v in RuleGeneratorV2._variable_lists_of_ast(pattern_ast)] + rewrite_lists = [set(v) for v in RuleGeneratorV2._variable_lists_of_ast(rewrite_ast)] + + ans: List[List[str]] = [] + while pattern_lists: + p = pattern_lists.pop() + matched_idx: Optional[int] = None + for idx, r in enumerate(rewrite_lists): + inter = p.intersection(r) + if inter: + ans.append(list(inter)) + matched_idx = idx + break + if matched_idx is not None: + rewrite_lists.pop(matched_idx) + return ans + + @staticmethod + def _variable_lists_of_ast(ast: Node) -> List[List[str]]: + """Collect element-variable name lists from mergeable positions. + + Mergeable positions include SELECT items, top-level AND chains, single-WHERE predicates, LIMIT placeholders, and JOIN ON placeholders. AND chains are flattened across their full left-associative depth so a AND b AND c yields a single 3-name list. + """ + # AND chains parse left-associatively, for example a AND b AND c + # becomes (a AND b) AND c. Collect lists only at top-most AND + # operators, where the parent is not also AND, and flatten the whole + # chain into a single list of placeholder names. + out: List[List[str]] = [] + + def _flatten_and(node: Node) -> List[str]: + if isinstance(node, OperatorNode) and node.name.lower() == "and": + names: List[str] = [] + for child in node.children: + names.extend(_flatten_and(child)) + return names + if isinstance(node, ElementVariableNode) and node.parent_alias is None: + # Only bare variables (not qualified column vars like .) + return [node.name] + return [] + + seen_and_ids: Set[int] = set() + + def _is_inside_and(parent: Optional[Node]) -> bool: + return ( + parent is not None + and isinstance(parent, OperatorNode) + and parent.name.lower() == "and" + ) + + def _visit(node: Node, parent: Optional[Node] = None) -> None: + if isinstance(node, SelectNode): + if not getattr(node, "distinct", False): + names: List[str] = [] + for child in node.children: + if isinstance(child, ElementVariableNode) and child.parent_alias is None: + # Only bare variables (not qualified column vars like .) + names.append(child.name) + if names: + out.append(names) + elif ( + isinstance(node, OperatorNode) + and node.name.lower() == "and" + and not _is_inside_and(parent) + ): + names = _flatten_and(node) + if names: + out.append(names) + seen_and_ids.add(id(node)) + elif isinstance(node, WhereNode) and len(node.children) == 1 and isinstance(node.children[0], ElementVariableNode): + out.append([node.children[0].name]) + elif isinstance(node, LimitNode) and isinstance(node.limit, ElementVariableNode): + out.append([node.limit.name]) + elif isinstance(node, JoinNode) and node.on_condition is not None: + oc = node.on_condition + if isinstance(oc, ElementVariableNode): + out.append([oc.name]) + + children = getattr(node, "children", None) + if children: + for child in children: + if isinstance(child, Node): + _visit(child, node) + + _visit(ast) + return out + + # _variable_lists_of_ast uses recursive AST traversal. The following + # nested-list helpers remain for _merge_variable_list_in_ast. + + @staticmethod + def subtrees(pattern_ast: Node, rewrite_ast: Node) -> List[Node]: + """Return subtrees that appear (structurally equal) in both pattern and rewrite, eligible to share an element variable. + + Pairs are matched first-fit between the two sides' candidate lists. + """ + pattern_subtrees = RuleGeneratorV2._subtrees_of_ast(pattern_ast) + rewrite_subtrees = RuleGeneratorV2._subtrees_of_ast(rewrite_ast) + ans: List[Node] = [] + while pattern_subtrees: + pattern_subtree = pattern_subtrees.pop() + for idx, rewrite_subtree in enumerate(rewrite_subtrees): + if pattern_subtree == rewrite_subtree: + ans.append(pattern_subtree) + rewrite_subtrees.pop(idx) + break + return ans + + @staticmethod + def _subtrees_of_ast(ast: Node) -> List[Node]: + """Return deep copies of every fully-variablized subtree candidate inside ast. + + A subtree is included only if _is_subtree_candidate accepts it for its parent context, and duplicates are de-duped by deparsed (or structural) key. + """ + out: List[Node] = [] + seen: Set[str] = set() + + def _visit(node: Node, parent: Optional[Node] = None) -> None: + if RuleGeneratorV2._is_subtree_candidate(node, parent): + try: + key = RuleGeneratorV2.deparse(node) + except Exception: + key = RuleGeneratorV2._structural_key(node) + if key not in seen: + seen.add(key) + out.append(copy.deepcopy(node)) + children = getattr(node, "children", None) + if isinstance(children, list): + for child in children: + if isinstance(child, Node): + _visit(child, node) + elif isinstance(children, set): + for child in children: + if isinstance(child, Node): + _visit(child, node) + + _visit(ast) + return out + + @staticmethod + def _structural_key(node: Node) -> str: + """Return a stable string fingerprint of node based on its type, scalar attributes, and recursively-keyed children. + + Used as a fallback dedup key in _subtrees_of_ast when deparse cannot render a node. + """ + parts: List[str] = [type(node).__name__] + for attr in ("name", "value", "alias", "distinct", "parent_alias"): + if hasattr(node, attr): + parts.append(f"{attr}={getattr(node, attr)!r}") + children = getattr(node, "children", None) or [] + if isinstance(children, (list, set)): + child_keys: List[str] = [] + for child in list(children): + if isinstance(child, Node): + child_keys.append(RuleGeneratorV2._structural_key(child)) + else: + child_keys.append(repr(child)) + parts.append("(" + ",".join(child_keys) + ")") + return "|".join(parts) + + @staticmethod + def _is_subtree_candidate(node: Node, parent: Optional[Node] = None) -> bool: + """Return True when node is a position-aware subtree replaceable by an element variable. + + Column and literal nodes only qualify in SELECT, GROUP BY, or ORDER BY positions. Set-variable nodes qualify under SELECT, single-WHERE, single-WHEN, or OR-chain parents. Other nodes must have at least one variabilized child and no un-variabilized leaves. + """ + if isinstance( + node, + ( + QueryNode, + CompoundQueryNode, + CaseNode, + SelectNode, + FromNode, + WhereNode, + GroupByNode, + HavingNode, + JoinNode, + OrderByItemNode, + OrderByNode, + LimitNode, + SubqueryNode, + WhenThenNode, + ), + ): + return False + + if isinstance(node, ElementVariableNode): + # Column variables (qualified or bare) are subtree candidates only as + # standalone SELECT, GROUP BY, or ORDER BY items. + return isinstance(parent, (SelectNode, GroupByNode, OrderByItemNode)) + + if isinstance(node, SetVariableNode): + # SELECT-position set vars can be lifted into a fresh element var + # during SELECT/GROUP BY split iterations. + if isinstance(parent, SelectNode): + return True + # A fully collapsed AND chain qualifies only when the set var + # stands alone as a WHERE or WHEN predicate, or as an OR branch. + # If it is mixed with other conjuncts under an AND, keep it as a + # set variable. + if isinstance(parent, (WhereNode, WhenThenNode)): + return True + if ( + isinstance(parent, OperatorNode) + and parent.name.lower() == "or" + ): + return True + return False + + if isinstance(node, (LiteralNode, VariableLiteralNode)): + return False + + var_count = 0 + for child in getattr(node, "children", []) or []: + if isinstance(child, (QueryNode, CompoundQueryNode, SelectNode, FromNode, WhereNode, JoinNode, SubqueryNode)): + return False + if isinstance(child, list): + return False + if isinstance(child, Node): + if isinstance(child, (ElementVariableNode, SetVariableNode, VariableLiteralNode)): + var_count += 1 + continue + if isinstance(child, LiteralNode): + continue + return False + return var_count >= 1 + + @staticmethod + def branches(pattern_ast: Node, rewrite_ast: Node) -> List[Dict[str, object]]: + """Return branch descriptors (clauses or AND/OR conjuncts) that exist on both sides and are fully variablized. + + Each entry is a {"key": ..., "value": ...} dict suitable for drop_branch. Pairs are matched first-fit; only matched branches are returned. + """ + pattern_branches = RuleGeneratorV2._branch_entries_of_ast(pattern_ast) + rewrite_branches = RuleGeneratorV2._branch_entries_of_ast(rewrite_ast) + out: List[Dict[str, object]] = [] + remaining = list(rewrite_branches) + while pattern_branches: + pb_public, pb_target = pattern_branches.pop() + for idx, (rb_public, rb_target) in enumerate(remaining): + if RuleGeneratorV2._branch_values_match(pb_public, rb_public, pb_target, rb_target): + out.append(pb_public) + remaining.pop(idx) + break + return out + + @staticmethod + def _branch_values_match( + pb: Dict[str, object], + rb: Dict[str, object], + pb_target: object, + rb_target: object, + ) -> bool: + if pb.get("key") != rb.get("key"): + return False + return RuleGeneratorV2._branch_targets_match(pb_target, rb_target) + + @staticmethod + def _branch_targets_match(pb_target: object, rb_target: object) -> bool: + if pb_target == rb_target: + return True + if isinstance(pb_target, Node) and isinstance(rb_target, Node): + try: + ps = RuleGeneratorV2.deparse(copy.deepcopy(pb_target)) + rs = RuleGeneratorV2.deparse(copy.deepcopy(rb_target)) + except Exception: + return False + return ps.lower() == rs.lower() + return False + + @staticmethod + def _branch_entries_of_ast(ast: Node) -> List[Tuple[Dict[str, object], object]]: + """Enumerate (public_descriptor, internal_target) pairs for every branch in ast that branches could potentially drop. + + Handles full queries, AND/OR chains with one entry per conjunct or disjunct, and equality RHS singletons. Public descriptors are the dicts surfaced by branches; internal targets are the actual nodes used by _drop_branch_in_ast. + """ + if isinstance(ast, QueryNode): + out: List[Tuple[Dict[str, object], object]] = [] + select = RuleGeneratorV2._first_clause(ast, NodeType.SELECT) + from_clause = RuleGeneratorV2._first_clause(ast, NodeType.FROM) + where = RuleGeneratorV2._first_clause(ast, NodeType.WHERE) + group_by = RuleGeneratorV2._first_clause(ast, NodeType.GROUP_BY) + having = RuleGeneratorV2._first_clause(ast, NodeType.HAVING) + order_by = RuleGeneratorV2._first_clause(ast, NodeType.ORDER_BY) + limit = RuleGeneratorV2._first_clause(ast, NodeType.LIMIT) + offset = RuleGeneratorV2._first_clause(ast, NodeType.OFFSET) + # Treat SELECT and SELECT DISTINCT as separate branch categories. + select_is_distinct = isinstance(select, SelectNode) and bool(getattr(select, "distinct", False)) + plain_select = select if (select is not None and not select_is_distinct) else None + is_select_only_wrapper = ( + select is not None + and from_clause is None + and where is None + and all(clause is None for clause in (group_by, having, order_by, limit, offset)) + ) + if select is not None and ( + is_select_only_wrapper or RuleGeneratorV2._is_branch_clause("select", select) + ): + select_target: object = select + if is_select_only_wrapper: + select_target = "__select_wrapper__" + if isinstance(select, SelectNode) and len(select.children) == 1: + child = select.children[0] + if isinstance(child, SetVariableNode): + out.append(({"key": "select", "value": "set_variable"}, select_target)) + elif isinstance(child, ColumnNode) and child.name == "*": + out.append(({"key": "select", "value": "all_columns"}, select_target)) + else: + out.append(({"key": "select", "value": None}, select_target)) + else: + out.append(({"key": "select", "value": None}, select_target)) + is_from_only_wrapper = ( + from_clause is not None + and select is None + and where is None + and all(clause is None for clause in (group_by, having, order_by, limit, offset)) + ) + if from_clause is not None and ( + is_from_only_wrapper or RuleGeneratorV2._is_branch_clause("from", from_clause) + ): + from_target: object = from_clause + if is_from_only_wrapper: + from_target = "__from_wrapper__" + if isinstance(from_clause, FromNode): + if any(isinstance(c, JoinNode) for c in from_clause.children): + out.append(({"key": "from", "value": "join_sources"}, from_target)) + else: + out.append(({"key": "from", "value": "table_sources"}, from_target)) + else: + out.append(({"key": "from", "value": None}, from_target)) + is_where_only_wrapper = ( + where is not None + and select is None + and from_clause is None + and all(clause is None for clause in (group_by, having, order_by, limit, offset)) + ) + if where is not None and ( + is_where_only_wrapper or RuleGeneratorV2._is_branch_clause("where", where) + ): + where_target: object = where + if is_where_only_wrapper: + where_target = "__where_wrapper__" + out.append(({"key": "where", "value": None}, where_target)) + if group_by is not None and RuleGeneratorV2._is_branch_clause("group_by", group_by): + out.append(({"key": "group_by", "value": None}, group_by)) + if having is not None and RuleGeneratorV2._is_branch_clause("having", having): + out.append(({"key": "having", "value": None}, having)) + if order_by is not None and RuleGeneratorV2._is_branch_clause("order_by", order_by): + out.append(({"key": "order_by", "value": None}, order_by)) + if limit is not None and RuleGeneratorV2._is_branch_clause("limit", limit): + out.append(({"key": "limit", "value": None}, limit)) + if offset is not None and RuleGeneratorV2._is_branch_clause("offset", offset): + out.append(({"key": "offset", "value": None}, offset)) + + # Apply SELECT/WHERE/FROM interactions. DISTINCT selects do not + # count as plain SELECT for these rules. + if plain_select is not None and where is not None: + out = [entry for entry in out if entry[0]["key"] != "from"] + if plain_select is None and from_clause is not None: + out = [entry for entry in out if entry[0]["key"] != "where"] + return out + + if isinstance(ast, OperatorNode) and ast.name.lower() in {"and", "or"}: + out = [] + for child in list(ast.children): + wrapped = OperatorNode(copy.deepcopy(child), ast.name.upper()) + if RuleGeneratorV2._is_branch_node(wrapped): + out.append(({"key": ast.name.lower(), "value": child}, child)) + return out + + if isinstance(ast, OperatorNode): + children = list(ast.children) + if ast.name == "=" and len(children) == 2: + return [({"key": "eq_rhs", "value": children[1]}, children[1])] + + return [] + + @staticmethod + def _is_branch_clause(key: str, clause: Node) -> bool: + if key == "select": + if isinstance(clause, SelectNode): + if len(clause.children) == 1: + child = clause.children[0] + if isinstance(child, ColumnNode) and child.name == "*": + return True + if isinstance(child, SetVariableNode): + return True + return RuleGeneratorV2._is_branch_node(child) + return RuleGeneratorV2._is_branch_node(clause) + return False + if key == "from": + if isinstance(clause, FromNode): + return RuleGeneratorV2._is_branch_node(clause) + return False + if key == "where": + if isinstance(clause, WhereNode) and len(clause.children) == 1: + return RuleGeneratorV2._is_branch_node(clause.children[0]) + return RuleGeneratorV2._is_branch_node(clause) + + @staticmethod + def _is_branch_node(node: Node) -> bool: + if isinstance(node, FromNode): + for child in node.children: + if isinstance(child, ElementVariableNode): + # generator-variablized table — ok + pass + elif isinstance(child, TableNode): + # Concrete table — branch is not fully variablized. + return False + elif isinstance(child, JoinNode): + if not RuleGeneratorV2._is_branch_node(child): + return False + else: + return False + return True + if isinstance(node, JoinNode): + # A JOIN counts as a branch source when all of its operands and + # the optional ON-condition contain nothing un-variablized. + for child in node.children: + if isinstance(child, ElementVariableNode): + # generator-variablized table — ok + pass + elif isinstance(child, TableNode): + return False + else: + if RuleGeneratorV2._tables_of_ast(child): + return False + cols = RuleGeneratorV2.columns(child, child) + if cols and not (len(cols) == 1 and cols[0] == "*"): + return False + if RuleGeneratorV2._literal_counts(child): + return False + if RuleGeneratorV2._variable_lists_of_ast(child): + return False + return True + if isinstance(node, WhereNode): + predicates = list(node.children) + if len(predicates) == 1: + return RuleGeneratorV2._is_branch_node(predicates[0]) + return False + if RuleGeneratorV2._tables_of_ast(node): + return False + columns = RuleGeneratorV2.columns(node, node) + if columns: + return len(columns) == 1 and columns[0] == "*" + if RuleGeneratorV2._literal_counts(node): + return False + if RuleGeneratorV2._variable_lists_of_ast(node): + return False + return True + + @staticmethod + def _replace_column_in_ast(ast: Node, column: str, external_name: str) -> Node: + """Replace every ColumnNode whose name == column (and any non-DISTINCT SELECT *) with an ElementVariableNode(external_name) in ast. + + The first column variabilized also captures bare * in plain SELECT clauses, so they share a single variable. SELECT DISTINCT * is kept separate and is only rewritten when the requested column itself is *. + """ + # Every column variabilization also rewrites any remaining plain + # SELECT * to the same variable. This causes the first column processed + # to share its variable with *. SELECT DISTINCT * is kept separate and + # is only rewritten when the requested column itself is *. + non_distinct_select_star_ids: Set[int] = set() + if column != "*": + for node in RuleGeneratorV2._walk(ast): + if isinstance(node, SelectNode) and not getattr(node, "distinct", False): + for child in node.children: + if isinstance(child, ColumnNode) and child.name == "*": + non_distinct_select_star_ids.add(id(child)) + + to_replace = [] + for node in RuleGeneratorV2._walk(ast): + if not isinstance(node, ColumnNode): + continue + if node.name == column: + to_replace.append((node, ElementVariableNode(external_name, parent_alias=node.parent_alias, alias=node.alias))) + elif ( + node.name == "*" + and column != "*" + and id(node) in non_distinct_select_star_ids + ): + to_replace.append((node, ElementVariableNode(external_name, parent_alias=node.parent_alias, alias=node.alias))) + + for old_node, new_node in to_replace: + if old_node is ast: + ast = new_node + else: + RuleGeneratorV2._replace_node_reference(ast, old_node, new_node) + return ast + + @staticmethod + def _replace_literal_in_ast( + ast: Node, + literal: Union[str, numbers.Number], + external_name: str, + ) -> Node: + """Substitute every occurrence of literal in ast with the new variable. + + String literals become VariableLiteralNode (preserving surrounding % wildcards). + Numeric literals and LIMIT/OFFSET values become ElementVariableNode. + """ + to_replace = [] + for node in RuleGeneratorV2._walk(ast): + if isinstance(node, LimitNode): + if isinstance(literal, numbers.Number) and node.limit == literal: + node.limit = ElementVariableNode(external_name) + continue + if node.type != NodeType.LITERAL: + continue + value = getattr(node, "value", None) + + if isinstance(literal, str) and isinstance(value, str): + if value == literal: + to_replace.append((node, VariableLiteralNode(external_name))) + elif value.replace("%", "") == literal: + prefix = "%" if value.startswith("%") else "" + suffix = "%" if value.endswith("%") else "" + to_replace.append((node, VariableLiteralNode(external_name, prefix=prefix, suffix=suffix))) + continue + + if isinstance(literal, numbers.Number) and isinstance(value, numbers.Number) and value == literal: + to_replace.append((node, ElementVariableNode(external_name))) + + for old_node, new_node in to_replace: + if old_node is ast: + ast = new_node + else: + RuleGeneratorV2._replace_node_reference(ast, old_node, new_node) + return ast + + @staticmethod + def _replace_table_in_ast( + ast: Node, + target_value: str, + target_name: str, + placeholder_token: str, + ) -> Node: + """Replace every matching TableNode with an ElementVariableNode(placeholder_token) and update qualified column refs in ast. + + A bare-named reference to target_value is also matched even when its alias disagrees with target_name, so a single variable can cover both an aliased outer reference and a bare-named reference inside a subquery. + placeholder_token here is actually the external_name (e.g. "x1") passed from variablize_table. + ColumnNode.parent_alias is set to ElementVariableNode(placeholder_token) so the formatter and rewriter handle it via isinstance checks. + """ + # A bare-table reference, with no explicit alias, is also matched when + # its value equals the target's value even if target_name differs. This + # lets one table variable cover both an aliased outer reference and a + # bare-named reference to the same underlying table. + match_aliases: Set[str] = set() + to_replace = [] + for node in RuleGeneratorV2._walk(ast): + if not isinstance(node, TableNode): + continue + current_alias = node.alias if isinstance(node.alias, str) else node.name + if node.name == target_value and ( + current_alias == target_name or current_alias == node.name + ): + match_aliases.add(current_alias) + to_replace.append((node, ElementVariableNode(placeholder_token))) + + for old_node, new_node in to_replace: + if old_node is ast: + ast = new_node + else: + RuleGeneratorV2._replace_node_reference(ast, old_node, new_node) + + if not match_aliases: + return ast + + # Column refs may use either the alias (t1.col), the table value + # (schema.table.col), or the target alias carried by the paired rule + # side. All of those prefixes should pick up the same table variable. + for node in RuleGeneratorV2._walk(ast): + if ( + isinstance(node, ColumnNode) + and isinstance(node.parent_alias, str) + and ( + node.parent_alias in match_aliases + or node.parent_alias == target_value + or node.parent_alias == target_name + ) + ): + node.parent_alias = ElementVariableNode(placeholder_token) + return ast + + @staticmethod + def _replace_subtree_in_ast(ast: Node, subtree: Node, replacement: Node, parent: Optional[Node] = None) -> Node: + """Position-aware replacement of every occurrence of subtree inside ast with a deep copy of replacement. + + Only swaps a match when the current parent context would have collected it as a candidate (so a column ref inside a JOIN ON predicate is left alone even when the same column is replaced as a SELECT item). Mutates and returns ast; replacement is deep-copied per substitution. + """ + # Subtree replacement is position-aware. A ColumnNode or LiteralNode + # is structurally the same object shape regardless of context, so only + # replace it when the current position is one where it would have been + # collected as a subtree candidate. + if ast == subtree and RuleGeneratorV2._is_subtree_candidate(ast, parent): + return copy.deepcopy(replacement) + if isinstance(ast, JoinNode): + had_on = ast.on_condition is not None + n_using = len(ast.using) if ast.using else 0 + children = getattr(ast, "children", None) + if isinstance(children, list): + for idx, child in enumerate(children): + if isinstance(child, Node): + new_child = RuleGeneratorV2._replace_subtree_in_ast(child, subtree, replacement, ast) + if new_child is not child: + children[idx] = new_child + RuleGeneratorV2._resync_parallel_attrs(ast, child, new_child) + elif isinstance(children, set): + replacements: List[Tuple[Node, Node]] = [] + new_children: Set[Node] = set() + for child in children: + new_child = RuleGeneratorV2._replace_subtree_in_ast(child, subtree, replacement, ast) + new_children.add(new_child) + if new_child is not child: + replacements.append((child, new_child)) + ast.children = new_children + for old, new in replacements: + RuleGeneratorV2._resync_parallel_attrs(ast, old, new) + + if isinstance(ast, JoinNode): + RuleGeneratorV2._resync_join_attrs(ast, had_on, n_using) + elif isinstance(ast, UnaryOperatorNode): + ast.operand = ast.children[0] + elif isinstance(ast, CompoundQueryNode): + ast.left = ast.children[0] + ast.right = ast.children[1] + elif isinstance(ast, SubqueryNode) and isinstance(ast.children, set): + pass + return ast + + @staticmethod + def _drop_branch_in_ast(ast: Node, branch: Dict[str, object]) -> Node: + """Return a new AST with the branch described by branch removed from ast. + + Handles AND/OR conjunct removal, equality RHS unwrapping, and per-clause QueryNode trimming. Dropping a sole FROM that wraps a subquery returns the inner query. May return the original ast if no branch matches. + """ + if isinstance(ast, OperatorNode): + key = branch.get("key") + if key == "eq_rhs": + children = list(ast.children) + if ast.name == "=" and len(children) == 2 and children[1] == branch.get("value"): + return children[0] + if key == ast.name.lower(): + children = list(ast.children) + remaining = [child for child in children if child != branch.get("value")] + if len(remaining) == 1: + return remaining[0] + ast.children = remaining + return ast + return ast + + if not isinstance(ast, QueryNode): + return ast + key = branch.get("key") + if key == "select": + sel = RuleGeneratorV2._first_clause(ast, NodeType.SELECT) + reduced = RuleGeneratorV2._query_without_clause(ast, NodeType.SELECT) + if isinstance(reduced, QueryNode) and isinstance(sel, SelectNode): + if not any( + RuleGeneratorV2._first_clause(reduced, t) + for t in ( + NodeType.SELECT, + NodeType.FROM, + NodeType.WHERE, + NodeType.GROUP_BY, + NodeType.HAVING, + NodeType.ORDER_BY, + NodeType.LIMIT, + NodeType.OFFSET, + ) + ): + if len(sel.children) == 1: + return sel.children[0] + return reduced + if key == "from": + from_clause = RuleGeneratorV2._first_clause(ast, NodeType.FROM) + reduced = RuleGeneratorV2._query_without_clause(ast, NodeType.FROM) + # When FROM is the only clause and contains a single subquery, + # unwrap it to the subquery's inner query. + if ( + isinstance(reduced, QueryNode) + and len(reduced.children) == 0 + and isinstance(from_clause, FromNode) + and len(from_clause.children) == 1 + ): + source = next(iter(from_clause.children)) + if isinstance(source, SubqueryNode): + inner = next(iter(source.children), None) + if isinstance(inner, Node): + return inner + return reduced + if key == "where": + reduced = RuleGeneratorV2._query_without_clause(ast, NodeType.WHERE) + # If this was a WHERE-scope wrapper, unwrap back to condition expression. + if isinstance(reduced, QueryNode) and len(reduced.children) == 0: + wh = RuleGeneratorV2._first_clause(ast, NodeType.WHERE) + if isinstance(wh, WhereNode) and len(wh.children) == 1: + return wh.children[0] + return reduced + if key == "group_by": + return RuleGeneratorV2._query_without_clause(ast, NodeType.GROUP_BY) + if key == "having": + return RuleGeneratorV2._query_without_clause(ast, NodeType.HAVING) + if key == "order_by": + return RuleGeneratorV2._query_without_clause(ast, NodeType.ORDER_BY) + if key == "limit": + return RuleGeneratorV2._query_without_clause(ast, NodeType.LIMIT) + if key == "offset": + return RuleGeneratorV2._query_without_clause(ast, NodeType.OFFSET) + return ast + + @staticmethod + def _merge_variable_list_in_ast(ast: Node, variable_set: Set[str], set_name: str) -> Node: + """Collapse element variables in variable_set into a single SetVariableNode(set_name) wherever they appear in ast. + + Handles SELECT/GROUP BY lists, flattened AND chains, single-WHERE predicates, JOIN ON conditions, and LIMIT placeholders. Mutates ast in place and returns it. + """ + def _process_and_chain(and_node: OperatorNode) -> Optional[Node]: + # Flatten nested AND chains so that (a AND b) AND c is treated as + # one ordered list of predicates. + flat: List[Node] = [] + + def _flatten(n: Node) -> None: + if isinstance(n, OperatorNode) and n.name.lower() == "and": + for child in n.children: + if isinstance(child, Node): + _flatten(child) + return + flat.append(n) + + _flatten(and_node) + + flat_var_names = {c.name for c in flat if isinstance(c, ElementVariableNode)} + if not variable_set.issubset(flat_var_names): + return None + + new_children: List[Node] = [] + pending = False + for child in flat: + if isinstance(child, ElementVariableNode) and child.name in variable_set: + if not pending: + new_children.append(SetVariableNode(set_name)) + pending = True + continue + new_children.append(child) + + if len(new_children) == 1: + return new_children[0] + result: Node = new_children[0] + for child in new_children[1:]: + result = OperatorNode(result, "AND", child) + return result + + def _is_inside_and(parent: Optional[Node]) -> bool: + return ( + parent is not None + and isinstance(parent, OperatorNode) + and parent.name.lower() == "and" + ) + + def _visit(node: Node, parent: Optional[Node]) -> Node: + if isinstance(node, (SelectNode, GroupByNode)): + # Variable lists are discovered from SELECT and AND positions, + # but replacement still walks related list-bearing clauses and + # collapses any subset match. Apply that to GROUP BY so a + # singleton merged on the SELECT side also collapses the same + # column ref in the GROUP BY clause. Keep walking afterward: + # SELECT items can contain nested expressions and subqueries + # whose variables must be merged too. + new_children: List[Node] = [] + pending = False + changed = False + for child in node.children: + variable_name: Optional[str] = None + if isinstance(child, ElementVariableNode): + variable_name = child.name + + if variable_name is not None and variable_name in variable_set: + if not pending: + new_children.append(SetVariableNode(set_name)) + pending = True + changed = True + continue + + pending = False + new_children.append(child) + if changed: + node.children = new_children + + if isinstance(node, WhereNode): + if len(node.children) == 1 and isinstance(node.children[0], ElementVariableNode): + if node.children[0].name in variable_set: + node.children = [SetVariableNode(set_name)] + return node + # Otherwise fall through and recurse into children. + + if isinstance(node, JoinNode) and node.on_condition is not None: + oc = node.on_condition + if isinstance(oc, ElementVariableNode) and oc.name in variable_set: + replacement = SetVariableNode(set_name) + node.on_condition = replacement + if len(node.children) > 2: + node.children[2] = replacement + return node + + if isinstance(node, LimitNode) and isinstance(node.limit, ElementVariableNode) and node.limit.name in variable_set: + node.limit = SetVariableNode(set_name) + return node + + if ( + isinstance(node, OperatorNode) + and node.name.lower() == "and" + and not _is_inside_and(parent) + ): + replaced = _process_and_chain(node) + if replaced is not None: + return replaced + + if isinstance(node, JoinNode): + had_on = node.on_condition is not None + n_using = len(node.using) if node.using else 0 + children = getattr(node, "children", None) + if isinstance(children, list): + for idx, child in enumerate(children): + if isinstance(child, Node): + new_child = _visit(child, node) + if new_child is not child: + children[idx] = new_child + RuleGeneratorV2._resync_parallel_attrs(node, child, new_child) + elif isinstance(children, set): + new_set: Set[Node] = set() + replacements: List[Tuple[Node, Node]] = [] + for child in children: + new_child = _visit(child, node) + new_set.add(new_child) + if new_child is not child: + replacements.append((child, new_child)) + node.children = new_set + for old, new in replacements: + RuleGeneratorV2._resync_parallel_attrs(node, old, new) + + if isinstance(node, JoinNode): + RuleGeneratorV2._resync_join_attrs(node, had_on, n_using) + elif isinstance(node, UnaryOperatorNode): + node.operand = node.children[0] + elif isinstance(node, CompoundQueryNode): + node.left = node.children[0] + node.right = node.children[1] + + return node + + return _visit(ast, None) + + @staticmethod + def _replace_node_reference(root: Node, target: Node, replacement: Node) -> None: + """Splice replacement in for target everywhere target appears as a child within root. + + Mutates the tree in place and re-syncs parent attribute aliases via _resync_parallel_attrs. Raises ValueError if target is root itself, since the parent cannot rewire its own pointer. + """ + for node in RuleGeneratorV2._walk(root): + children = getattr(node, "children", None) + replaced_here = False + if isinstance(children, list): + for idx, child in enumerate(children): + if child is target: + children[idx] = replacement + replaced_here = True + elif isinstance(children, set): + if target in children: + children.remove(target) + children.add(replacement) + replaced_here = True + if replaced_here: + RuleGeneratorV2._resync_parallel_attrs(node, target, replacement) + if root is target: + raise ValueError("Cannot replace root node directly; expected nested target.") + + @staticmethod + def _resync_parallel_attrs(node: Node, target: Node, replacement: Node) -> None: + """Rewrite parallel attribute pointers on node (e.g. CaseNode.whens, WhenThenNode.when/then, JoinNode.on_condition) so they reference replacement instead of target. + + Many AST nodes carry named attributes that mirror entries in children; whenever children mutate, these parallel pointers must be re-synced or the formatter will read stale references. + """ + # Many AST nodes mirror children into named attributes (e.g. CaseNode. + # whens / else_val, WhenThenNode.when/then, JoinNode.on_condition). + # The formatter and other helpers read those attrs directly, so + # whenever we mutate children we must keep the parallel pointers in + # sync. Walk the node's __dict__ and substitute any reference that + # is target with replacement. + for attr_name, attr_value in list(node.__dict__.items()): + if attr_name == "children": + continue + if attr_value is target: + setattr(node, attr_name, replacement) + elif isinstance(attr_value, list): + for idx, item in enumerate(attr_value): + if item is target: + attr_value[idx] = replacement + elif isinstance(attr_value, tuple): + if any(item is target for item in attr_value): + setattr( + node, + attr_name, + tuple(replacement if item is target else item for item in attr_value), + ) + + @staticmethod + def _resync_join_attrs(join: JoinNode, had_on: bool, n_using: int) -> None: + """Re-sync JoinNode parallel pointers (left_table, right_table, on_condition, using) from its current children list. + + Caller passes the snapshot of whether the join had an ON clause and how many USING columns existed before the mutation; this method then partitions the post-mutation children accordingly. Mutates join in place. + """ + children = list(join.children) + if len(children) < 2: + return + join.left_table = children[0] # type: ignore[assignment] + join.right_table = children[1] # type: ignore[assignment] + rest = children[2:] + if had_on and rest: + join.on_condition = rest[0] # type: ignore[assignment] + using_rest = rest[1:] + else: + join.on_condition = None + using_rest = rest + if n_using and using_rest: + join.using = list(using_rest[:n_using]) + else: + join.using = None + + @staticmethod + def deparse(node: Node) -> str: + """Render a v2 AST node back into SQL text, including /<> placeholders. + + Wraps a partial node into a full QueryNode for formatting, runs QueryFormatter, fixes mo_sql_parsing's NATURAL JOIN quirk, then strips the synthetic SELECT/FROM/WHERE prefix to recover the original scope. + """ + working = copy.deepcopy(node) + full_query, scope = RuleGeneratorV2._extend_to_full_query(working) + sql = QueryFormatter().format(full_query) + # mo_sql_parsing renders NATURAL JOIN as , NATURAL JOIN() + # with an extra leading comma and no space before the parenthesis. + sql = re.sub(r",\s*NATURAL\s+JOIN\s*\(", " NATURAL JOIN (", sql) + return RuleGeneratorV2._extract_partial_sql(sql, scope) + + @staticmethod + def dereplaceVars(sql: str, mapping: Dict[str, str]) -> str: + """Substitute internal variable names back to user-facing markers (EV001 -> , SV001 -> <>). + + Iterates mapping (external-name -> internal-name) and rewrites every occurrence in sql using the markers from VarTypesInfo. + """ + out = sql + for external_name, internal_name in mapping.items(): + for var_type in VarType: + if internal_name.startswith(VarTypesInfo[var_type]["internalBase"]): + marker_start = VarTypesInfo[var_type]["markerStart"] + marker_end = VarTypesInfo[var_type]["markerEnd"] + out = out.replace(internal_name, f"{marker_start}{external_name}{marker_end}") + break + return out + + @staticmethod + def _extend_to_full_query(node: Node) -> tuple[Node, Scope]: + """Wrap a partial AST node into a full QueryNode so the formatter can render it. + + Returns (full_query, scope) where scope records what part of the synthetic SELECT * FROM t WHERE ... wrapper to strip back off after formatting. + """ + if isinstance(node, CompoundQueryNode): + return node, Scope.SELECT + if isinstance(node, QueryNode): + has_select = RuleGeneratorV2._query_has_clause(node, NodeType.SELECT) + has_from = RuleGeneratorV2._query_has_clause(node, NodeType.FROM) + has_where = RuleGeneratorV2._query_has_clause(node, NodeType.WHERE) + + if has_select: + return node, Scope.SELECT + + if has_from: + return QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=RuleGeneratorV2._first_clause(node, NodeType.FROM), + _where=RuleGeneratorV2._first_clause(node, NodeType.WHERE), + _group_by=RuleGeneratorV2._first_clause(node, NodeType.GROUP_BY), + _having=RuleGeneratorV2._first_clause(node, NodeType.HAVING), + _order_by=RuleGeneratorV2._first_clause(node, NodeType.ORDER_BY), + _limit=RuleGeneratorV2._first_clause(node, NodeType.LIMIT), + _offset=RuleGeneratorV2._first_clause(node, NodeType.OFFSET), + ), Scope.FROM + + if has_where: + return QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([TableNode("t")]), + _where=RuleGeneratorV2._first_clause(node, NodeType.WHERE), + _group_by=RuleGeneratorV2._first_clause(node, NodeType.GROUP_BY), + _having=RuleGeneratorV2._first_clause(node, NodeType.HAVING), + _order_by=RuleGeneratorV2._first_clause(node, NodeType.ORDER_BY), + _limit=RuleGeneratorV2._first_clause(node, NodeType.LIMIT), + _offset=RuleGeneratorV2._first_clause(node, NodeType.OFFSET), + ), Scope.WHERE + + return QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([TableNode("t")]), + _where=WhereNode([node]), + ), Scope.CONDITION + + @staticmethod + def _first_clause(query: QueryNode, node_type: NodeType) -> Optional[Node]: + """Return the first child of query whose .type matches node_type (or None if absent).""" + for child in query.children: + if child.type == node_type: + return child + return None + + @staticmethod + def _query_has_clause(query: QueryNode, node_type: NodeType) -> bool: + return RuleGeneratorV2._first_clause(query, node_type) is not None + + @staticmethod + def _query_without_clause(query: QueryNode, clause_type: NodeType) -> QueryNode: + return QueryNode( + _select=None if clause_type == NodeType.SELECT else RuleGeneratorV2._first_clause(query, NodeType.SELECT), + _from=None if clause_type == NodeType.FROM else RuleGeneratorV2._first_clause(query, NodeType.FROM), + _where=None if clause_type == NodeType.WHERE else RuleGeneratorV2._first_clause(query, NodeType.WHERE), + _group_by=None if clause_type == NodeType.GROUP_BY else RuleGeneratorV2._first_clause(query, NodeType.GROUP_BY), + _having=None if clause_type == NodeType.HAVING else RuleGeneratorV2._first_clause(query, NodeType.HAVING), + _order_by=None if clause_type == NodeType.ORDER_BY else RuleGeneratorV2._first_clause(query, NodeType.ORDER_BY), + _limit=None if clause_type == NodeType.LIMIT else RuleGeneratorV2._first_clause(query, NodeType.LIMIT), + _offset=None if clause_type == NodeType.OFFSET else RuleGeneratorV2._first_clause(query, NodeType.OFFSET), + ) + + @staticmethod + def _extract_partial_sql(full_sql: str, scope: Scope) -> str: + if scope == Scope.SELECT: + return full_sql + if scope == Scope.FROM: + return full_sql.replace("SELECT * ", "", 1) + if scope == Scope.WHERE: + return full_sql.replace("SELECT * FROM t ", "", 1) + return full_sql.replace("SELECT * FROM t WHERE ", "", 1) + + + + @staticmethod + def fingerPrint(rule: RuleV2) -> str: + """Return a stable fingerprint string for rule based on its deparsed pattern. + + Variable indices are normalized so that two rules that differ only in variable numbering share a fingerprint. Used to deduplicate rules in the generalization graph. + """ + ast = rule.get("pattern_ast") + if not isinstance(ast, Node): + raise TypeError("rule['pattern_ast'] must be an AST Node") + pattern = RuleGeneratorV2.deparse(ast) + return RuleGeneratorV2._fingerPrint(pattern) + + @staticmethod + def _fingerPrint(fingerprint: str) -> str: + out = fingerprint + out = re.sub(r"'()'", r"\1", out) + out = re.sub(r"", "", out) + out = re.sub(r"<>", "<>", out) + out = re.sub(r"''", "''", out) + out = re.sub(r"", "", out) + out = re.sub(r"<>", "<>", out) + return out + + @staticmethod + def numberOfVariables(rule: RuleV2) -> int: + """Return the count of declared variables in rule['mapping']. + + Used as a tie-breaker when picking the simplest rule among equivalents. + """ + mapping = rule.get("mapping") + if not isinstance(mapping, dict): + raise TypeError("rule['mapping'] must be a dict[str, str]") + return len(mapping.keys()) + + @staticmethod + def unify_variable_names(q0: str, q1: str) -> Tuple[str, str]: + """Renumber /<> placeholders in q0 and q1 consecutively in order of first appearance. + + Returns the rewritten pair (q0', q1'); e.g. and become and so two rules with equivalent placeholders compare equal. + """ + mapping: Dict[str, str] = {} + counter = 1 + + for token in re.findall(r"<<\w+>>|<\w+>", q0 + " " + q1): + if token not in mapping: + mapping[token] = f"<>" if token.startswith("<<") else f"" + counter += 1 + + def _replace(text: str) -> str: + return re.sub(r"<<\w+>>|<\w+>", lambda m: mapping.get(m.group(), m.group()), text) + + return _replace(q0), _replace(q1) + + @staticmethod + def _find_next_element_variable(mapping: Dict[str, str]) -> Tuple[Dict[str, str], str]: + """Allocate the next unused element variable in mapping and return (updated_mapping, external_name). + + Mutates mapping in place by inserting the new x? -> EV??? entry. + """ + max_external = 0 + max_internal = 0 + for external_name, internal_name in mapping.items(): + external_num = RuleGeneratorV2._suffix_int(external_name, "x") + if external_num is not None: + max_external = max(max_external, external_num) + internal_num = RuleGeneratorV2._suffix_int(internal_name, "EV") + if internal_num is not None: + max_internal = max(max_internal, internal_num) + + next_external = f"x{max_external + 1}" + next_internal = f"EV{str(max_internal + 1).zfill(3)}" + mapping[next_external] = next_internal + return mapping, next_external + + @staticmethod + def _find_next_set_variable(mapping: Dict[str, str]) -> Tuple[Dict[str, str], str]: + """Allocate the next unused set variable in mapping and return (updated_mapping, set_name). + + Mutates mapping in place by inserting the new y? -> SV??? entry. + """ + max_external = 0 + max_internal = 0 + for external_name, internal_name in mapping.items(): + external_num = RuleGeneratorV2._suffix_int(external_name, "y") + if external_num is not None: + max_external = max(max_external, external_num) + internal_num = RuleGeneratorV2._suffix_int(internal_name, "SV") + if internal_num is not None: + max_internal = max(max_internal, internal_num) + + next_external = f"y{max_external + 1}" + next_internal = f"SV{str(max_internal + 1).zfill(3)}" + mapping[next_external] = next_internal + return mapping, next_external + + @staticmethod + def _suffix_int(value: str, prefix: str) -> Optional[int]: + if not value.lower().startswith(prefix.lower()): + return None + suffix = value[len(prefix):] + if not suffix or not suffix.isdigit(): + return None + return int(suffix) + + @staticmethod + def _walk(node: Optional[Node]) -> Iterator[Node]: + """Pre-order yield every Node in the subtree rooted at node (including the node itself). + + Safe to call with None; non-Node children and missing children attributes are skipped. + """ + if node is None: + return + yield node + children = getattr(node, "children", None) + if not children: + return + for child in children: + yield from RuleGeneratorV2._walk(child) diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index 1159b13..3824746 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -27,6 +27,7 @@ OperatorNode, OrderByItemNode, OrderByNode, + CompoundQueryNode, QueryNode, SelectNode, SubqueryNode, @@ -34,6 +35,7 @@ UnaryOperatorNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -262,9 +264,13 @@ def _substitute_rule_vars( # @staticmethod def _extract_rule_fragment(query: Node, scope: Scope) -> Node: - # CompoundQueryNode (e.g. UNION) is always a full-query scope — return as-is if isinstance(query, CompoundQueryNode): + if scope != Scope.SELECT: + raise ValueError("Non-SELECT fragment scope is not supported for compound queries") return query + if not isinstance(query, QueryNode): + raise TypeError("expected QueryNode or CompoundQueryNode while extracting rule fragment") + frm = RuleParserV2._get_clause(query, NodeType.FROM) wh = RuleParserV2._get_clause(query, NodeType.WHERE) gb = RuleParserV2._get_clause(query, NodeType.GROUP_BY) @@ -345,31 +351,54 @@ def _replace_internal_in_string(s: str) -> str: return node pa = col.parent_alias nm = col.name - new_alias = _replace_internal_in_string(col.alias) if isinstance(col.alias, str) else col.alias + if isinstance(col.alias, str) and col.alias in rev: + new_alias: Optional[Union[str, ElementVariableNode]] = ElementVariableNode(rev[col.alias]) + elif isinstance(col.alias, str): + new_alias = _replace_internal_in_string(col.alias) + else: + new_alias = col.alias new_pa = _replace_internal_in_string(pa) if isinstance(pa, str) else pa + + # Bare column variable (no qualifier): promote to ElementVariableNode if pa is None and nm in rev: return RuleParserV2._placeholder_varnode(nm, rev[nm]) + + # Both name and parent_alias are variables if pa is not None and pa in rev and nm in rev: - return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=rev[pa]) + return ElementVariableNode(rev[nm], parent_alias=ElementVariableNode(rev[pa]), alias=new_alias) + + # Only parent_alias is a variable (concrete column, variable table qualifier) if pa is not None and pa in rev: - return ColumnNode(nm, _alias=new_alias, _parent_alias=rev[pa]) + return ColumnNode(nm, _alias=new_alias, _parent_alias=ElementVariableNode(rev[pa])) + + # Only column name is a variable (concrete table qualifier) if pa is not None and nm in rev: - return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=new_pa) + return ElementVariableNode(rev[nm], parent_alias=new_pa, alias=new_alias) + return ColumnNode(nm, _alias=new_alias, _parent_alias=new_pa) if node.type == NodeType.TABLE: t = node if not isinstance(t, TableNode): return node - # If table name is a SET variable placeholder (<>), promote to SetVariableNode - # so it matches any table or list of tables in the FROM clause. - # Element variable tokens (EV...) stay as TableNode so _match_node handles them. sv_base = VarTypesInfo[VarType.SetVariable]["internalBase"] + ev_base = VarTypesInfo[VarType.ElementVariable]["internalBase"] + + # SET variable table: promote to SetVariableNode if isinstance(t.name, str) and t.name in rev and t.name.startswith(sv_base): return SetVariableNode(rev[t.name]) + + # ELEMENT variable table: promote to ElementVariableNode + if isinstance(t.name, str) and t.name in rev and t.name.startswith(ev_base): + # alias may also be a variable + if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: + return ElementVariableNode(rev[t.name], alias=ElementVariableNode(rev[t.alias])) + return ElementVariableNode(rev[t.name]) + + # Concrete table new_name = rev.get(t.name, t.name) if isinstance(t.name, str) else t.name if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: - new_alias = rev[t.alias] + new_alias = ElementVariableNode(rev[t.alias]) else: new_alias = t.alias return TableNode(new_name, new_alias) @@ -378,13 +407,19 @@ def _replace_internal_in_string(s: str) -> str: lit = node if not isinstance(lit, LiteralNode): return node + alias = _replace_internal_in_string(lit.alias) if isinstance(getattr(lit, "alias", None), str) else getattr(lit, "alias", None) if isinstance(lit.value, str): - # If the entire literal value is an internal placeholder token, promote to var node + # Exact match: the entire literal is a placeholder token → variable literal if lit.value in rev: - return LiteralNode(rev[lit.value]) - # Otherwise substitute any embedded tokens (e.g. '%EV001%' to '%x%') - return LiteralNode(_replace_internal_in_string(lit.value)) - return LiteralNode(lit.value) + return VariableLiteralNode(rev[lit.value], _alias=alias) + # Embedded token: e.g. '%EV001%' → VariableLiteralNode with surrounding wildcards + stripped = lit.value.replace("%", "") + if stripped in rev: + prefix = "%" if lit.value.startswith("%") else "" + suffix = "%" if lit.value.endswith("%") else "" + return VariableLiteralNode(rev[stripped], prefix=prefix, suffix=suffix, _alias=alias) + return LiteralNode(_replace_internal_in_string(lit.value), _alias=alias) + return LiteralNode(lit.value, _alias=alias) if node.type == NodeType.QUERY: q = node @@ -401,6 +436,14 @@ def _replace_internal_in_string(s: str) -> str: _offset=RuleParserV2._as_rule_ast(RuleParserV2._get_clause(q, NodeType.OFFSET), rev), ) + if node.type == NodeType.COMPOUND_QUERY: + cq = node + if not isinstance(cq, CompoundQueryNode): + return node + left = RuleParserV2._substitute_placeholders(cq.left, rev) + right = RuleParserV2._substitute_placeholders(cq.right, rev) + return CompoundQueryNode(left, right, cq.is_all) + if node.type == NodeType.SELECT: sn = node if not isinstance(sn, SelectNode): @@ -451,6 +494,8 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(lim, LimitNode): return node if isinstance(lim.limit, str): + if lim.limit in rev: + return LimitNode(ElementVariableNode(rev[lim.limit])) return LimitNode(_replace_internal_in_string(lim.limit)) return LimitNode(lim.limit) @@ -459,6 +504,8 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(off, OffsetNode): return node if isinstance(off.offset, str): + if off.offset in rev: + return OffsetNode(ElementVariableNode(rev[off.offset])) return OffsetNode(_replace_internal_in_string(off.offset)) return OffsetNode(off.offset) @@ -473,13 +520,19 @@ def _replace_internal_in_string(s: str) -> str: j = node if not isinstance(j, JoinNode): return node - ch = list(j.children) - left = RuleParserV2._substitute_placeholders(ch[0], rev) - right = RuleParserV2._substitute_placeholders(ch[1], rev) + left = RuleParserV2._substitute_placeholders(j.left_table, rev) + right = RuleParserV2._substitute_placeholders(j.right_table, rev) on_expr = ( - RuleParserV2._substitute_placeholders(ch[2], rev) if len(ch) > 2 else None + RuleParserV2._substitute_placeholders(j.on_condition, rev) + if j.on_condition is not None + else None ) - return JoinNode(left, right, j.join_type, on_expr) + using_cols = ( + [RuleParserV2._substitute_placeholders(c, rev) for c in j.using] + if j.using + else None + ) + return JoinNode(left, right, j.join_type, on_expr, using_cols) if node.type == NodeType.SUBQUERY: sq = node @@ -494,7 +547,12 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(f, FunctionNode): return node new_args = [RuleParserV2._substitute_placeholders(a, rev) for a in f.children] - alias = _replace_internal_in_string(f.alias) if isinstance(f.alias, str) else f.alias + if isinstance(f.alias, str) and f.alias in rev: + alias = ElementVariableNode(rev[f.alias]) + elif isinstance(f.alias, str): + alias = _replace_internal_in_string(f.alias) + else: + alias = f.alias return FunctionNode(f.name, _args=new_args, _alias=alias) if node.type == NodeType.LIST: diff --git a/data/rules.py b/data/rules.py index c99a6ff..f2bfad3 100644 --- a/data/rules.py +++ b/data/rules.py @@ -2,6 +2,7 @@ from core.rule_parser import RuleParser from core.rule_parser_v2 import RuleParserV2 +from core.rule import RuleV2 rules = [ # PostgresSQL Rules @@ -764,28 +765,26 @@ def get_rule(key: str) -> dict: # fetch one rule by key using the v2 AST-based parser # -def get_rule_v2(key: str) -> dict: +def get_rule_v2(key: str) -> RuleV2: rule = next(filter(lambda x: x['key'] == key, rules), None) if rule is None: raise ValueError(f"Rule {key} not found") result = RuleParserV2.parse(rule['pattern'], rule['rewrite']) - # TODO: reuse v1 parse_actions? + # Action variables stay as external names (s1/t2/...) since match() binds those + # in memo, so parse with an identity mapping rather than external->internal. identity_mapping = json.dumps({k: k for k in result.mapping}) - actions_json = RuleParser.parse_actions(rule['actions'], identity_mapping) - return { - 'id': rule['id'], - 'key': rule['key'], - 'name': rule['name'], - 'pattern': rule['pattern'], - 'pattern_ast': result.pattern_ast, - 'rewrite': rule['rewrite'], - 'rewrite_ast': result.rewrite_ast, - 'mapping': result.mapping, - 'actions': rule['actions'], - 'actions_json': json.loads(actions_json), - 'database': rule['database'], - 'examples': rule['examples'], - } + actions_json = json.loads(RuleParser.parse_actions(rule['actions'], identity_mapping)) + return RuleV2( + id=rule['id'], + key=rule['key'], + pattern=rule['pattern'], + pattern_ast=result.pattern_ast, + rewrite=rule['rewrite'], + rewrite_ast=result.rewrite_ast, + mapping=result.mapping, + actions=rule['actions'], + actions_json=actions_json, + ) # return a list of rules (json attributes are in str) diff --git a/tests/test_rule_generator_v2.py b/tests/test_rule_generator_v2.py new file mode 100644 index 0000000..13b3096 --- /dev/null +++ b/tests/test_rule_generator_v2.py @@ -0,0 +1,2124 @@ +from __future__ import annotations + +import pytest + +from core.ast.enums import NodeType +from core.ast.node import QueryNode +from core.query_formatter import QueryFormatter +from core.query_parser import QueryParser +from core.rule_generator_v2 import RuleGeneratorV2 +from core.rule import RuleV2 +from core.rule_parser_v2 import RuleParserV2, VarType +from data.rules import get_rule_v2 as get_rule + + +def _build_rule(pattern: str, rewrite: str) -> RuleV2: + parsed = RuleParserV2.parse(pattern, rewrite) + return RuleV2( + pattern=pattern, + rewrite=rewrite, + pattern_ast=parsed.pattern_ast, + rewrite_ast=parsed.rewrite_ast, + mapping=parsed.mapping, + ) + + +def _has_clause(query: QueryNode, clause_type: NodeType) -> bool: + return any(child.type == clause_type for child in query.children) + + +def _norm_sql(sql: str) -> str: + return " ".join(sql.split()) + + +_PARSER = QueryParser() +_FORMATTER = QueryFormatter() + + +def parse(query: str): + return _PARSER.parse(query.strip()) + + +def format(ast): + return _FORMATTER.format(ast) + + +def _assert_matches_expected( + q0: str, q1: str, expected_pattern: str, expected_rewrite: str +) -> None: + """Compare v2 output against a hand-written expected pattern/rewrite. + + Both sides are normalized through unify_variable_names so concrete + placeholder names do not need to align. + """ + parse(q0) + parse(q1) + rule_v2 = RuleGeneratorV2.generate_general_rule(q0, q1) + got_p, got_r = RuleGeneratorV2.unify_variable_names(rule_v2["pattern"], rule_v2["rewrite"]) + exp_p, exp_r = RuleGeneratorV2.unify_variable_names(expected_pattern, expected_rewrite) + assert _norm_sql(got_p) == _norm_sql(exp_p) + assert _norm_sql(got_r) == _norm_sql(exp_r) + + +def _assert_matches_rule(q0: str, q1: str, key: str) -> None: + rule = get_rule(key) + assert rule is not None + _assert_matches_expected(q0, q1, rule["pattern"], rule["rewrite"]) + + + + +def test_dereplaceVars_mixed_element_and_set_vars(): + pattern = """ + select SV001 + from EV001 EV002, + EV003 EV004 + where EV002.EV005=EV004.EV006 + and SV002 + """ + mapping = { + "x1": "EV001", + "y1": "SV001", + "x2": "EV002", + "y2": "SV002", + "x3": "EV003", + "x4": "EV004", + "x5": "EV005", + "x6": "EV006", + } + + dereplaced = RuleGeneratorV2.dereplaceVars(pattern, mapping) + assert "<>" in dereplaced + assert ".=." in dereplaced + assert "<>" in dereplaced + + +def test_dereplaceVars_1(): + assert RuleGeneratorV2.dereplaceVars("CAST(EV001 AS DATE)", {"x": "EV001"}) == "CAST( AS DATE)" + assert RuleGeneratorV2.dereplaceVars("EV001", {"x": "EV001"}) == "" + + +def test_dereplaceVars_2(): + pattern = """ + select SV001 + from EV001 EV002, + EV003 EV004 + where EV002.EV005=EV004.EV006 + and SV002 + """ + rewrite = """ + select SV001 + from EV001 EV002 + where SV002 + """ + mapping = { + "x1": "EV001", + "y1": "SV001", + "x2": "EV002", + "y2": "SV002", + "x3": "EV003", + "x4": "EV004", + "x5": "EV005", + "x6": "EV006", + } + assert RuleGeneratorV2.dereplaceVars(pattern, mapping) == """ + select <> + from , + + where .=. + and <> + """ + assert RuleGeneratorV2.dereplaceVars(rewrite, mapping) == """ + select <> + from + where <> + """ + + +def test_deparse_condition_scope_expression(): + result = RuleParserV2.parse("CAST( AS DATE)", "") + assert RuleGeneratorV2.deparse(result.pattern_ast) == "CAST( AS DATE)" + assert RuleGeneratorV2.deparse(result.rewrite_ast) == "" + + +def test_deparse_1(): + result = RuleParserV2.parse("CAST(V1 AS DATE)", "V1") + assert RuleGeneratorV2.deparse(result.pattern_ast) == "CAST(V1 AS DATE)" + assert RuleGeneratorV2.deparse(result.rewrite_ast) == "V1" + + +def test_deparse_2(): + result = RuleParserV2.parse("STRPOS(LOWER(V1), 'V2') > 0", "V1 ILIKE '%V2%'") + assert RuleGeneratorV2.deparse(result.pattern_ast) == "STRPOS(LOWER(V1), 'V2') > 0" + assert RuleGeneratorV2.deparse(result.rewrite_ast) == "V1 ILIKE '%V2%'" + + + +def test_columns_1(): + result = RuleParserV2.parse("STRPOS(LOWER(text), 'iphone') > 0", "ILIKE(text, '%iphone%')") + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"text"} + + +def test_columns_2(): + result = RuleParserV2.parse("CAST(state_name AS TEXT)", "state_name") + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"state_name"} + + +def test_columns_excludes_variable_placeholders(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, employee e2 + where e1. = e2. + and e1.age > 17 + and e2.salary > 35000 + """, + """ + select e1.name, e1.age, e1.salary + from employee e1 + where e1.age > 17 + and e1.salary > 35000 + """, + ) + columns = RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast) + assert set(columns) == {"name", "age", "salary"} + + +def test_columns_4(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1. = e2. + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"name", "age", "salary"} + + +def test_columns_3(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"name", "age", "salary", "id"} + + +def test_columns_5(): + result = RuleParserV2.parse( + """ + select e1.* + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.* + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"*", "id", "age", "salary"} + + +def test_columns_6(): + result = RuleParserV2.parse( + """ + select * + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS'); + """, + """ + select distinct * + from employee emp, department dept + where emp.workdept = dept.deptno + and dept.deptname = 'OPERATIONS'; + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"*", "workdept", "deptno", "deptname"} + + +def test_columns_7(): + result = RuleParserV2.parse( + """ + SELECT * + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminrolei2_.admin_role_id = 1 + """, + """ + SELECT * + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == {"*", "admin_permission_id", "admin_role_id"} + + +def test_columns_8(): + result = RuleParserV2.parse( + """ + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 + AND adminrolei2_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + """, + """ + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE adminpermi0_.is_friendly = 1 + AND allroles1_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + """, + ) + assert set(RuleGeneratorV2.columns(result.pattern_ast, result.rewrite_ast)) == { + "admin_permission_id", + "description", + "is_friendly", + "name", + "permission_type", + "admin_role_id", + } + + +def test_literals_1(): + result = RuleParserV2.parse("STRPOS(LOWER(text), 'iphone') > 0", "ILIKE(text, '%iphone%')") + assert set(RuleGeneratorV2.literals(result.pattern_ast, result.rewrite_ast)) == {"iphone"} + + +def test_literals_2(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + assert set(RuleGeneratorV2.literals(result.pattern_ast, result.rewrite_ast)) == {17, 35000} + + +def test_literals_3(): + result = RuleParserV2.parse( + """ + SELECT * + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminrolei2_.admin_role_id = 1 + """, + """ + SELECT * + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + """, + ) + assert set(RuleGeneratorV2.literals(result.pattern_ast, result.rewrite_ast)) == {1} + + +def test_tables_1(): + result = RuleParserV2.parse("STRPOS(LOWER(text), 'iphone') > 0", "ILIKE(text, '%iphone%')") + assert RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast) == [] + + +def test_tables_2(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.name, + e1.age, + e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + expected = {("employee", "e1"), ("employee", "e2")} + actual = {(t["value"], t["name"]) for t in RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast)} + assert actual == expected + + +def test_tables_3(): + result = RuleParserV2.parse( + """ + select .name, .age, .salary + from , + where . = . + and .age > 17 + and .salary > 35000; + """, + """ + SELECT .name, .age, .salary + FROM + WHERE .age > 17 + AND .salary > 35000; + """, + ) + assert RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast) == [] + + +def test_tables_3_excludes_variable_tables(): + result = RuleParserV2.parse( + """ + select .name, .age, .salary + from , + where . = . + and .age > 17 + and .salary > 35000 + """, + """ + select .name, .age, .salary + from + where .age > 17 + and .salary > 35000 + """, + ) + assert RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast) == [] + + +def test_tables_4(): + result = RuleParserV2.parse( + """ + select * + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS'); + """, + """ + select distinct * + from employee, department + where employee.workdept = department.deptno + and department.deptname = 'OPERATIONS'; + """, + ) + actual = {(t["value"], t["name"]) for t in RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast)} + assert actual == {("employee", "employee"), ("department", "department")} + + +def test_tables_4_subquery_tables(): + result = RuleParserV2.parse( + """ + select * + from employee + where workdept in ( + select deptno + from department + where deptname = 'OPERATIONS' + ) + """, + """ + select distinct * + from employee, department + where employee.workdept = department.deptno + and department.deptname = 'OPERATIONS' + """, + ) + expected = {("employee", "employee"), ("department", "department")} + actual = {(t["value"], t["name"]) for t in RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast)} + assert actual == expected + + +def test_tables_5(): + result = RuleParserV2.parse( + """ + SELECT * + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminrolei2_.admin_role_id = 1 + """, + """ + SELECT * + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + """, + ) + actual = {(t["value"], t["name"]) for t in RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast)} + assert actual == { + ("blc_admin_permission", "adminpermi0_"), + ("blc_admin_role_permission_xref", "allroles1_"), + ("blc_admin_role", "adminrolei2_"), + } + + +def test_tables_6(): + result = RuleParserV2.parse( + """ + SELECT Count(*) + FROM (SELECT 1 AS one + FROM group_histories + WHERE group_histories.group_id = 2578 + AND group_histories.action = 2 + ORDER BY group_histories.created_at DESC + LIMIT 25 offset 0) subquery_for_count + """, + """ + SELECT Count(*) + FROM (SELECT 1 AS one + FROM group_histories + WHERE group_histories.group_id = 2578 + AND group_histories.action = 2 + LIMIT 25 offset 0) AS subquery_for_count + """, + ) + actual = {(t["value"], t["name"]) for t in RuleGeneratorV2.tables(result.pattern_ast, result.rewrite_ast)} + assert actual == {("group_histories", "group_histories")} + + +def test_variablize_literal_1(): + rule = _build_rule("STRPOS(LOWER(text), 'iphone') > 0", "text ILIKE '%iphone%'") + out = RuleGeneratorV2.variablize_literal(rule, "iphone") + assert out["pattern"] == "STRPOS(LOWER(text), '') > 0" + assert out["rewrite"] == "text ILIKE '%%'" + + +def test_variablize_literal_2(): + rule = _build_rule( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000 + """, + """ + SELECT e1.name, e1.age, e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000 + """, + ) + out = RuleGeneratorV2.variablize_literal(rule, 17) + assert "e1.age > " in out["pattern"] + assert "e1.age > " in out["rewrite"] + + +def test_variablize_column_1(): + rule = _build_rule("CAST(created_at AS DATE)", "created_at") + out = RuleGeneratorV2.variablize_column(rule, "created_at") + assert out["pattern"] == "CAST( AS DATE)" + assert out["rewrite"] == "" + + +def test_variablize_column_2(): + rule = _build_rule("STRPOS(LOWER(text), 'iphone') > 0", "text ILIKE '%iphone%'") + out = RuleGeneratorV2.variablize_column(rule, "text") + assert out["pattern"] == "STRPOS(LOWER(), 'iphone') > 0" + assert out["rewrite"] == " ILIKE '%iphone%'" + + +def test_variablize_column_3(): + rule = _build_rule( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000 + """, + """ + SELECT e1.name, e1.age, e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000 + """, + ) + out = RuleGeneratorV2.variablize_column(rule, "id") + assert _norm_sql(out["pattern"]) == _norm_sql( + "SELECT e1.name, e1.age, e2.salary FROM employee AS e1, employee AS e2 WHERE e1. = e2. AND e1.age > 17 AND e2.salary > 35000" + ) + assert _norm_sql(out["rewrite"]) == _norm_sql( + "SELECT e1.name, e1.age, e1.salary FROM employee AS e1 WHERE e1.age > 17 AND e1.salary > 35000" + ) + + +def test_variablize_column_4(): + rule = _build_rule( + """ + select * + from employee + where workdept in + (select deptno + from department + where deptname = 'OPERATIONS'); + """, + """ + select distinct * + from employee emp, department dept + where emp.workdept = dept.deptno + and dept.deptname = 'OPERATIONS'; + """, + ) + out = RuleGeneratorV2.variablize_column(rule, "*") + assert _norm_sql(out["pattern"]) == _norm_sql( + "SELECT FROM employee WHERE workdept IN (SELECT deptno FROM department WHERE deptname = 'OPERATIONS')" + ) + assert _norm_sql(out["rewrite"]) == _norm_sql( + "SELECT DISTINCT FROM employee AS emp, department AS dept WHERE emp.workdept = dept.deptno AND dept.deptname = 'OPERATIONS'" + ) + + +def test_variablize_table_1(): + rule = _build_rule( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000 + """, + """ + SELECT e1.name, e1.age, e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000 + """, + ) + out = RuleGeneratorV2.variablize_table(rule, {"value": "employee", "name": "e1"}) + assert "FROM , employee AS e2" in out["pattern"] or "FROM , employee e2" in out["pattern"] + assert ".id = e2.id" in out["pattern"] + assert ("FROM " in out["rewrite"]) or ("FROM x1" in out["rewrite"]) + + +def test_variablize_table_2(): + rule = _build_rule( + """ + SELECT .name, .age, e2.salary + FROM , employee AS e2 + WHERE .id = e2.id + AND .age > 17 + AND e2.salary > 35000 + """, + """ + SELECT .name, .age, .salary + FROM + WHERE .age > 17 + AND .salary > 35000 + """, + ) + out = RuleGeneratorV2.variablize_table(rule, {"value": "employee", "name": "e2"}) + assert "FROM , " in out["pattern"] + assert ".id = .id" in out["pattern"] + assert ".salary > 35000" in out["pattern"] + assert "FROM " in out["rewrite"] + + +def test_variablize_table_3(): + rule = _build_rule( + """ + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 + AND adminrolei2_.admin_role_id = 1 + """, + """ + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + AND adminpermi0_.is_friendly = 1 + """, + ) + out = RuleGeneratorV2.variablize_table( + rule, {"value": "blc_admin_permission", "name": "adminpermi0_"} + ) + assert "FROM " in out["pattern"] + assert "JOIN blc_admin_role_permission_xref AS allroles1_" in out["pattern"] + assert ".admin_permission_id = allroles1_.admin_permission_id" in out["pattern"] + assert ".is_friendly = 1" in out["pattern"] + assert "FROM " in out["rewrite"] + + +def test_subtrees_1(): + result = RuleParserV2.parse("STRPOS(LOWER(text), 'iphone') > 0", "text ILIKE '%iphone%'") + assert RuleGeneratorV2.subtrees(result.pattern_ast, result.rewrite_ast) == [] + + +def test_subtrees_2(): + result = RuleParserV2.parse( + """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000; + """, + """ + SELECT e1.name, e1.age, e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000; + """, + ) + assert RuleGeneratorV2.subtrees(result.pattern_ast, result.rewrite_ast) == [] + + +def test_subtrees_3(): + result = RuleParserV2.parse( + """ + select .name, .age, .salary + from , + where . = . + and .age > 17 + and .salary > 35000; + """, + """ + SELECT .name, .age, .salary + FROM + WHERE .age > 17 + AND .salary > 35000; + """, + ) + assert RuleGeneratorV2.subtrees(result.pattern_ast, result.rewrite_ast) == [] + + +def test_subtrees_4(): + result = RuleParserV2.parse( + """ + select ., .age, .salary + from , + where . = . + and .age > 17 + and .salary > 35000; + """, + """ + SELECT ., .age, .salary + FROM + WHERE .age > 17 + AND .salary > 35000; + """, + ) + assert [RuleGeneratorV2.deparse(t) for t in RuleGeneratorV2.subtrees(result.pattern_ast, result.rewrite_ast)] == ["."] + + +def test_subtrees_5(): + result = RuleParserV2.parse( + """ + SELECT . AS admin_pe1_4_, . AS descript2_4_, . AS is_frien3_4_, . AS name4_4_, . AS permissi5_4_ + FROM + INNER JOIN ON . = . + INNER JOIN ON . = . + WHERE . = + AND . = + ORDER BY . ASC + LIMIT + """, + """ + SELECT . AS admin_pe1_4_, . AS descript2_4_, . AS is_frien3_4_, . AS name4_4_, . AS permissi5_4_ + FROM + INNER JOIN ON . = . + WHERE . = + AND . = + ORDER BY . ASC + LIMIT + """, + ) + actual = set(RuleGeneratorV2.deparse(t) for t in RuleGeneratorV2.subtrees(result.pattern_ast, result.rewrite_ast)) + assert actual == { + ". = ", + ". = .", + ".", + ".", + ".", + ".", + ".", + } + + +def test_variablize_subtree_1(): + rule = _build_rule( + """ + select ., .age, .salary + from , + where . = . + and .age > 17 + and .salary > 35000 + """, + """ + SELECT ., .age, .salary + FROM + WHERE .age > 17 + AND .salary > 35000 + """, + ) + subtree = RuleGeneratorV2.subtrees(rule["pattern_ast"], rule["rewrite_ast"])[0] + out = RuleGeneratorV2.variablize_subtree(rule, subtree) + assert _norm_sql(out["pattern"]) == _norm_sql( + "SELECT , .age, .salary FROM , WHERE . = . AND .age > 17 AND .salary > 35000" + ) + assert _norm_sql(out["rewrite"]) == _norm_sql( + "SELECT , .age, .salary FROM WHERE .age > 17 AND .salary > 35000" + ) + + +def test_variablize_subtrees_1(): + rule = _build_rule( + """ + SELECT . AS admin_pe1_4_, . AS descript2_4_, . AS is_frien3_4_, . AS name4_4_, . AS permissi5_4_ + FROM + INNER JOIN ON . = . + INNER JOIN ON . = . + WHERE . = + AND . = + ORDER BY . ASC + LIMIT + """, + """ + SELECT . AS admin_pe1_4_, . AS descript2_4_, . AS is_frien3_4_, . AS name4_4_, . AS permissi5_4_ + FROM + INNER JOIN ON . = . + WHERE . = + AND . = + ORDER BY . ASC + LIMIT + """, + ) + children = RuleGeneratorV2.variablize_subtrees(rule) + assert len(children) == 7 + + +def test_variable_lists_1(): + result = RuleParserV2.parse( + """ + SELECT , , , , + FROM + INNER JOIN ON + INNER JOIN ON . = . + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + """ + SELECT , , , , + FROM + INNER JOIN ON + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + ) + variable_lists = RuleGeneratorV2.variable_lists(result.pattern_ast, result.rewrite_ast) + normalized = {",".join(sorted(v)) for v in variable_lists} + assert "x14,x15,x16,x17,x18" in normalized + assert "x12" in normalized + assert "x11" in normalized + + +def test_variable_lists_2(): + result = RuleParserV2.parse( + """ + SELECT + FROM + INNER JOIN ON + INNER JOIN ON . = . + WHERE + AND . = + """, + """ + SELECT + FROM + INNER JOIN ON + WHERE . = + AND + """, + ) + variable_lists = RuleGeneratorV2.variable_lists(result.pattern_ast, result.rewrite_ast) + normalized = {",".join(sorted(v)) for v in variable_lists} + assert "x11" in normalized + assert "x8" in normalized + + +def test_variable_lists_3(): + result = RuleParserV2.parse( + """ + SELECT , , , , , + FROM + LEFT OUTER JOIN ON + LEFT OUTER JOIN ON . = . + WHERE . = + """, + """ + SELECT , , , , , + FROM + LEFT OUTER JOIN ON + WHERE . = + """, + ) + variable_lists = RuleGeneratorV2.variable_lists(result.pattern_ast, result.rewrite_ast) + normalized = {tuple(sorted(v)) for v in variable_lists} + assert ("x13",) in normalized + assert ("x14", "x15", "x16", "x17", "x18", "x19") in normalized + + +def test_merge_variable_list_1(): + rule = _build_rule( + """ + SELECT , , , , + FROM + INNER JOIN ON + INNER JOIN ON . = . + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + """ + SELECT , , , , + FROM + INNER JOIN ON + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + ) + out = RuleGeneratorV2.merge_variable_list(rule, ["x18", "x17", "x16", "x15", "x14"]) + assert "SELECT <>" in out["pattern"] + assert "SELECT <>" in out["rewrite"] + + +def test_merge_variable_list_2(): + rule = _build_rule( + """ + SELECT , , , , + FROM + INNER JOIN ON + INNER JOIN ON . = . + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + """ + SELECT , , , , + FROM + INNER JOIN ON + WHERE + AND . = + ORDER BY . ASC + LIMIT + """, + ) + out = RuleGeneratorV2.merge_variable_list(rule, ["x11"]) + assert "LIMIT <>" in out["pattern"] + assert "LIMIT <>" in out["rewrite"] + + +def test_merge_variable_list_descends_into_select_expressions(): + rule = _build_rule( + """ + SELECT COALESCE(., ) + FROM + LEFT JOIN ON + """, + """ + SELECT COALESCE( + (SELECT . + FROM + WHERE AND . = 1 + LIMIT 1), + + ) + FROM + """, + ) + out = RuleGeneratorV2.merge_variable_list(rule, ["x9"]) + assert "LEFT JOIN ON <>" in out["pattern"] + assert "WHERE <> AND . = 1" in out["rewrite"] + + +def test_branches_1(): + result = RuleParserV2.parse( + "SELECT <> FROM WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "SELECT <> FROM WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) + assert {"key": "select", "value": "set_variable"} in branches + + +def test_branches_2(): + result = RuleParserV2.parse( + "FROM WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "FROM WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) + assert {"key": "from", "value": "table_sources"} in branches + + +def test_branches_3(): + result = RuleParserV2.parse( + "WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) + assert {"key": "where", "value": None} in branches + + +def test_branches_4(): + result = RuleParserV2.parse( + "CAST(created_at AS DATE) = TIMESTAMP ''", + "created_at = TIMESTAMP ''", + ) + branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) + actual = {(b["key"], RuleGeneratorV2.deparse(b["value"])) for b in branches} + assert actual == {("eq_rhs", "TIMESTAMP('')")} + + +def test_branches_5(): + result = RuleParserV2.parse( + "SELECT * FROM WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "SELECT * FROM WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) + actual = {(b["key"], b["value"]) for b in branches if isinstance(b["value"], str)} + assert ("select", "all_columns") in actual + + +def test_drop_branch_1(): + rule = _build_rule( + "SELECT <> FROM WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "SELECT <> FROM WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + out = RuleGeneratorV2.drop_branch(rule, {"key": "select", "value": "set_variable"}) + parsed = RuleParserV2.parse(out["pattern"], out["rewrite"]) + assert isinstance(parsed.pattern_ast, QueryNode) + assert isinstance(parsed.rewrite_ast, QueryNode) + assert _has_clause(parsed.pattern_ast, NodeType.SELECT) is False + assert _has_clause(parsed.rewrite_ast, NodeType.SELECT) is False + assert _has_clause(parsed.pattern_ast, NodeType.FROM) is True + assert _has_clause(parsed.rewrite_ast, NodeType.FROM) is True + assert _has_clause(parsed.pattern_ast, NodeType.WHERE) is True + assert _has_clause(parsed.rewrite_ast, NodeType.WHERE) is True + + +def test_drop_branch_2(): + rule = _build_rule( + "FROM WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "FROM WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + out = RuleGeneratorV2.drop_branch(rule, {"key": "from", "value": "table_sources"}) + parsed = RuleParserV2.parse(out["pattern"], out["rewrite"]) + assert isinstance(parsed.pattern_ast, QueryNode) + assert isinstance(parsed.rewrite_ast, QueryNode) + assert _has_clause(parsed.pattern_ast, NodeType.SELECT) is False + assert _has_clause(parsed.rewrite_ast, NodeType.SELECT) is False + assert _has_clause(parsed.pattern_ast, NodeType.FROM) is False + assert _has_clause(parsed.rewrite_ast, NodeType.FROM) is False + assert _has_clause(parsed.pattern_ast, NodeType.WHERE) is True + assert _has_clause(parsed.rewrite_ast, NodeType.WHERE) is True + + +def test_drop_branch_3(): + rule = _build_rule( + "WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + ) + out = RuleGeneratorV2.drop_branch(rule, {"key": "where", "value": None}) + parsed = RuleParserV2.parse(out["pattern"], out["rewrite"]) + assert not isinstance(parsed.pattern_ast, QueryNode) + assert not isinstance(parsed.rewrite_ast, QueryNode) + + +def test_drop_branch_4(): + rule = _build_rule( + "CAST(created_at AS DATE) = TIMESTAMP ''", + "created_at = TIMESTAMP ''", + ) + branch = RuleGeneratorV2.branches(rule["pattern_ast"], rule["rewrite_ast"])[0] + out = RuleGeneratorV2.drop_branch(rule, branch) + assert _norm_sql(out["pattern"]) == _norm_sql("CAST(created_at AS DATE)") + assert _norm_sql(out["rewrite"]) == _norm_sql("created_at") + + +def test_fingerprint_normalizes_numbered_placeholders(): + rule = _build_rule("SELECT , FROM WHERE <>", "SELECT FROM WHERE <>") + fp = RuleGeneratorV2.fingerPrint(rule) + assert "" in fp + assert "<>" in fp + assert "" not in fp + assert "<>" not in fp + + +def test_fingerprint_same_for_renamed_variables(): + rule1 = _build_rule("CAST( AS DATE)", "") + rule2 = _build_rule("CAST( AS DATE)", "") + assert RuleGeneratorV2.fingerPrint(rule1) == RuleGeneratorV2.fingerPrint(rule2) + + +def test_unify_variable_names_1(): + q0 = "FROM <> INNER JOIN ON <>. = ." + q1 = "FROM " + a, b = RuleGeneratorV2.unify_variable_names(q0, q1) + assert a == "FROM <> INNER JOIN ON <>. = ." + assert b == "FROM " + + +def test_unify_variable_names_2(): + q0 = " <>" + q1 = "" + a, b = RuleGeneratorV2.unify_variable_names(q0, q1) + assert a == " <>" + assert b == "" + + +def test_unify_variable_names_3(): + q0 = " <> " + q1 = " <> " + a, b = RuleGeneratorV2.unify_variable_names(q0, q1) + assert a == " <> " + assert b == " <> " + + +def test_number_of_variables(): + rule = _build_rule("SELECT , <> FROM ", "SELECT , <> FROM ") + assert RuleGeneratorV2.numberOfVariables(rule) == 3 + + +def test_generate_general_rule_1(): + rule = RuleGeneratorV2.generate_general_rule("SELECT CAST(created_at AS DATE)", "SELECT created_at") + assert rule["pattern"] == "CAST( AS DATE)" + assert rule["rewrite"] == "" + + +def test_generate_general_rule_2(): + rule = RuleGeneratorV2.generate_general_rule( + "SELECT STRPOS(LOWER(text), 'iphone') > 0", + "SELECT ILIKE(text, '%iphone%')", + ) + assert rule["pattern"] == "STRPOS(LOWER(), '') > 0" + assert rule["rewrite"] == " ILIKE '%%'" + + +def test_generate_general_rule_8(): + q0 = "SELECT * FROM t WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'" + q1 = "SELECT * FROM t WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'" + _assert_matches_rule(q0, q1, "remove_cast_date") + + +def test_generate_general_rule_3(): + q0 = """ + select e1.name, e1.age, e2.salary + from employee e1, + employee e2 + where e1.id = e2.id + and e1.age > 17 + and e2.salary > 35000 + """ + q1 = """ + SELECT e1.name, e1.age, e1.salary + FROM employee e1 + WHERE e1.age > 17 + AND e1.salary > 35000 + """ + _assert_matches_expected( + q0, + q1, + """ + SELECT <>, . + FROM , + WHERE . = . + AND <> + AND . > + """, + """ + SELECT <>, . + FROM + WHERE <> + AND . > + """, + ) + + +def test_generate_general_rule_4(): + q0 = """ + SELECT * + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminrolei2_.admin_role_id = 1 + """ + q1 = """ + SELECT * + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + """ + _assert_matches_expected( + q0, + q1, + """ + FROM + INNER JOIN + ON <> + INNER JOIN + ON . = . + WHERE . = + """, + """ + FROM + INNER JOIN + ON <> + WHERE . = + """, + ) + + +def test_generate_general_rule_5(): + q0 = """ + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 + AND adminrolei2_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + """ + q1 = """ + SELECT adminpermi0_.admin_permission_id AS admin_pe1_4_, + adminpermi0_.description AS descript2_4_, + adminpermi0_.is_friendly AS is_frien3_4_, + adminpermi0_.name AS name4_4_, + adminpermi0_.permission_type AS permissi5_4_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE adminpermi0_.is_friendly = 1 + AND allroles1_.admin_role_id = 1 + ORDER BY adminpermi0_.description ASC + LIMIT 50 + """ + _assert_matches_expected( + q0, + q1, + """ + FROM + INNER JOIN + ON <> + INNER JOIN + ON . = . + WHERE <> + AND . = + """, + """ + FROM + INNER JOIN + ON <> + WHERE <> + AND . = + """, + ) + + +def test_generate_general_rule_6(): + q0 = """ + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 + AND adminrolei2_.admin_role_id = 1 + """ + q1 = """ + SELECT Count(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 + AND adminpermi0_.is_friendly = 1 + """ + _assert_matches_expected( + q0, + q1, + """ + FROM + INNER JOIN + ON <> + INNER JOIN + ON . = . + WHERE <> + AND . = + """, + """ + FROM + INNER JOIN + ON <> + WHERE . = + AND <> + """, + ) + + +def test_generate_general_rule_7(): + q0 = """ + SELECT o_auth_applications.id + FROM o_auth_applications + INNER JOIN authorizations + ON o_auth_applications.id = authorizations.o_auth_application_id + WHERE authorizations.user_id = 1465 + """ + q1 = """ + SELECT authorizations.o_auth_application_id + FROM authorizations AS authorizations + WHERE authorizations.user_id = 1465 + """ + _assert_matches_expected( + q0, + q1, + """ + SELECT . + FROM + INNER JOIN + ON . = . + """, + """ + SELECT . + FROM + """, + ) + + +def test_generate_general_rule_9(): + q0 = """ + SELECT SUM(1), CAST(state_name AS TEXT) + FROM tweets + WHERE CAST(DATE_TRUNC('QUARTER', CAST(created_at AS DATE)) AS DATE) IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND (STRPOS(LOWER(text), 'iphone') > 0) + GROUP BY 2 + """ + q1 = """ + SELECT SUM(1), CAST(state_name AS TEXT) + FROM tweets + WHERE CAST(DATE_TRUNC('QUARTER', CAST(created_at AS DATE)) AS DATE) IN + ((TIMESTAMP '2016-10-01 00:00:00.000'), + (TIMESTAMP '2017-01-01 00:00:00.000'), + (TIMESTAMP '2017-04-01 00:00:00.000')) + AND text ILIKE '%iphone%' + GROUP BY 2 + """ + _assert_matches_expected( + q0, + q1, + "STRPOS(LOWER(), '') > 0", + " ILIKE '%%'", + ) + + +def test_generate_general_rule_10(): + q0 = """ + select * + from employee + where workdept in + (select deptno from department where deptname = 'OPERATIONS') + """ + q1 = """ + select distinct * + from employee, department + where employee.workdept = department.deptno + and department.deptname = 'OPERATIONS' + """ + + expected_pattern = """ + SELECT + FROM + WHERE IN (SELECT + FROM + WHERE = '') + """ + expected_rewrite = """ + SELECT DISTINCT + FROM , + WHERE . = . + AND . = '' + """ + _assert_matches_expected(q0, q1, expected_pattern, expected_rewrite) + + +def test_generate_general_rule_11(): + q0 = """ + SELECT Count(*) + FROM (SELECT 1 AS one + FROM group_histories + WHERE group_histories.group_id = 2578 + AND group_histories.action = 2 + ORDER BY group_histories.created_at DESC + LIMIT 25 offset 0) subquery_for_count + """ + q1 = """ + SELECT Count(*) + FROM (SELECT 1 AS one + FROM group_histories + WHERE group_histories.group_id = 2578 + AND group_histories.action = 2 + LIMIT 25 offset 0) AS subquery_for_count + """ + _assert_matches_expected( + q0, + q1, + """ + FROM ORDER BY . DESC + """, + """ + FROM + """, + ) + + +def test_generate_general_rule_12(): + q0 = "SELECT student.ids from student WHERE student.id = 100 AND student.abc = 100" + q1 = "SELECT student.id from student WHERE student.id = 100" + _assert_matches_expected( + q0, + q1, + "SELECT . FROM WHERE <> AND . = ", + "SELECT . FROM WHERE <>", + ) + + +def test_generate_general_rule_13(): + q0 = """ + SELECT COUNT(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission adminpermi0_ + INNER JOIN blc_admin_role_permission_xref allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + INNER JOIN blc_admin_role adminrolei2_ + ON allroles1_.admin_role_id = adminrolei2_.admin_role_id + WHERE adminpermi0_.is_friendly = 1 AND adminrolei2_.admin_role_id = 1 + """ + q1 = """ + SELECT COUNT(adminpermi0_.admin_permission_id) AS col_0_0_ + FROM blc_admin_permission AS adminpermi0_ + INNER JOIN blc_admin_role_permission_xref AS allroles1_ + ON adminpermi0_.admin_permission_id = allroles1_.admin_permission_id + WHERE allroles1_.admin_role_id = 1 AND adminpermi0_.is_friendly = 1 + """ + _assert_matches_expected( + q0, + q1, + """ + FROM INNER JOIN ON <> INNER JOIN ON . = . + WHERE <> AND . = + """, + """ + FROM INNER JOIN ON <> WHERE . = AND <> + """, + ) + + +def test_generate_general_rule_14(): + q0 = """select distinct c.customer_id from table1 c join table2 l on c.customer_id = l.customer_id join table3 cal on c.customer_id = cal.customer_id WHERE (l.customer_group_id = 'loyalty' and c.loyalty_number = '123456789') or (cal.account_id = '123456789' and cal.account_type = 'loyalty')""" + q1 = """SELECT customer_id FROM table1 c JOIN table2 l USING (customer_id) JOIN table3 cal USING (customer_id) WHERE l.customer_group_id = 'loyalty' AND c.loyalty_number = '123456789' UNION SELECT customer_id FROM table1 c JOIN table2 l USING (customer_id) JOIN table3 cal USING (customer_id) WHERE cal.account_id = '123456789' AND cal.account_type = 'loyalty'""" + _exp_rw = ( + "SELECT FROM JOIN USING JOIN USING WHERE \n" + "UNION\n" + "SELECT FROM JOIN USING JOIN USING WHERE " + ) + _assert_matches_expected( + q0, + q1, + "SELECT DISTINCT . FROM JOIN ON . = . JOIN ON . = . WHERE OR ", + _exp_rw, + ) + + +def test_generate_general_rule_15(): + q0 = "select * from A a left join B b on a.id = b.cid where b.cl1 = 's1' or b.cl1 ='s2' or b.cl1 ='s3'" + q1 = "select * from A a left join B b on a.id = b.cid where b.cl1 in ('s1','s2','s3')" + _assert_matches_rule(q0, q1, "spreadsheet_id_7") + + +def test_generate_general_rule_16(): + q0 = """SELECT historicoestatusrequisicion_id, requisicion_id, estatusrequisicion_id, comentario, fecha_estatus, usuario_id FROM historicoestatusrequisicion hist1 WHERE requisicion_id IN (SELECT requisicion_id FROM historicoestatusrequisicion hist2 WHERE usuario_id = 27 AND estatusrequisicion_id = 1) ORDER BY requisicion_id, estatusrequisicion_id""" + q1 = """SELECT hist1.historicoestatusrequisicion_id, hist1.requisicion_id, hist1.estatusrequisicion_id, hist1.comentario, hist1.fecha_estatus, hist1.usuario_id FROM historicoestatusrequisicion hist1 JOIN historicoestatusrequisicion hist2 ON hist2.requisicion_id = hist1.requisicion_id WHERE hist2.usuario_id = 27 AND hist2.estatusrequisicion_id = 1 ORDER BY hist1.requisicion_id, hist1.estatusrequisicion_id""" + _assert_matches_rule(q0, q1, "spreadsheet_id_11") + + +def test_generate_general_rule_17(): + q0 = """select wpis_id from spoleczniak_oznaczone where etykieta_id in( select tag_id from spoleczniak_subskrypcje where postac_id = 376476 )""" + q1 = """select spoleczniak_oznaczone.wpis_id from spoleczniak_oznaczone inner join spoleczniak_subskrypcje on spoleczniak_subskrypcje.tag_id = spoleczniak_oznaczone.etykieta_id where spoleczniak_subskrypcje.postac_id = 376476""" + _assert_matches_expected( + q0, + q1, + "SELECT FROM WHERE IN (SELECT FROM WHERE = )", + "SELECT . FROM INNER JOIN ON . = . WHERE . = ", + ) + + +def test_generate_general_rule_18(): + q0 = "SELECT EMP.EMPNO FROM EMP WHERE EMP.EMPNO > 10 AND EMP.EMPNO <= 10" + q1 = "SELECT EMPNO FROM EMP WHERE FALSE" + _assert_matches_expected( + q0, + q1, + "SELECT . FROM WHERE . > AND . <= ", + "SELECT FROM WHERE False", + ) + + +def test_generate_general_rule_19(): + q0 = "SELECT max(id) FROM Emp" + q1 = "SELECT max(DISTINCT id) FROM Emp" + rule = RuleGeneratorV2.generate_general_rule(q0, q1) + q0_rule, q1_rule = RuleGeneratorV2.unify_variable_names(rule["pattern"], rule["rewrite"]) + assert q0_rule == "MAX()" + assert q1_rule == "MAX(DISTINCT )" + + +def test_generate_general_rule_20(): + q0 = """ + SELECT * + FROM accounts + WHERE LOWER(accounts.firstname) = LOWER('Sam') + AND accounts.id IN ( + SELECT addresses.account_id + FROM addresses + WHERE LOWER(addresses.name) = LOWER('Street1') + ) + AND accounts.id IN ( + SELECT alternate_ids.account_id + FROM alternate_ids + WHERE alternate_ids.alternate_id_glbl = '5' + ) + """ + q1 = """ + SELECT * + FROM accounts + JOIN addresses ON accounts.id = addresses.account_id + JOIN alternate_ids ON accounts.id = alternate_ids.account_id + WHERE LOWER(accounts.firstname) = LOWER('Sam') + AND LOWER(addresses.name) = LOWER('Street1') + AND alternate_ids.alternate_id_glbl = '5' + """ + _assert_matches_rule(q0, q1, "subquery_to_joins") + + +def test_generate_general_rule_21(): + q0 = """ + SELECT product.name, category.description, category.category_id + FROM product NATURAL JOIN category + WHERE product.price > 100 + AND product.category_id = 4 + """ + q1 = """ + SELECT product.name, category.description, category.category_id + FROM product INNER JOIN category ON product.category_id = category.category_id + WHERE product.price > 100 + """ + _assert_matches_expected( + q0, + q1, + "FROM NATURAL JOIN () WHERE <> AND . = 4", + "FROM INNER JOIN ON . = . WHERE <>", + ) + + +def test_generate_general_rule_22(): + q0 = """ + SELECT + t1.CPF, + DATE(t1.data), + CASE WHEN SUM(CASE WHEN t1.login_ok = true THEN 1 ELSE 0 END) >= 1 + THEN true + ELSE false + END + FROM db_risco.site_rn_login AS t1 + GROUP BY t1.CPF, DATE(t1.data) + """ + q1 = """ + SELECT + t1.CPF, + t1.data + FROM ( + SELECT CPF, DATE(data) + FROM db_risco.site_rn_login + WHERE login_ok = true + ) t1 + GROUP BY t1.CPF, t1.data + """ + _assert_matches_expected( + q0, + q1, + "SELECT <>, DATE(.), CASE WHEN SUM(CASE WHEN . = THEN ELSE END) >= THEN ELSE END FROM GROUP BY <>, DATE(.)", + "SELECT <>, . FROM (SELECT , DATE() FROM WHERE = ) AS t1 GROUP BY <>, .", + ) + + +def test_recommend_simple_rules_1(): + examples = [ + { + "q0": "SELECT * FROM employee WHERE workdept IN (SELECT deptno FROM department WHERE deptname = 'OPERATIONS')", + "q1": "SELECT DISTINCT * FROM employee, department where employee.workdept = department.deptno AND department.deptname = 'OPERATIONS'", + } + ] + rules = RuleGeneratorV2.recommend_simple_rules(examples) + assert _norm_sql(rules[0]["pattern"]) == _norm_sql( + "SELECT * FROM WHERE workdept IN (SELECT deptno FROM department WHERE deptname = 'OPERATIONS')" + ) + assert _norm_sql(rules[0]["rewrite"]) == _norm_sql( + "SELECT DISTINCT * FROM , department WHERE .workdept = department.deptno AND department.deptname = 'OPERATIONS'" + ) + + +def test_recommend_simple_rules_2(): + examples = [ + { + "q0": "SELECT Count(*) FROM (SELECT 1 AS one FROM group_histories WHERE group_histories.group_id = 2578 AND group_histories.action = 2 ORDER BY group_histories.created_at DESC LIMIT 25 offset 0) subquery_for_count", + "q1": "SELECT Count(*) FROM (SELECT 1 AS one FROM group_histories WHERE group_histories.group_id = 2578 AND group_histories.action = 2 LIMIT 25 offset 0) AS subquery_for_count", + }, + { + "q0": "SELECT Count(*) FROM (SELECT 1 AS one FROM gh WHERE gh.group_id = 2578 AND gh.action = 2 ORDER BY gh.created_at DESC LIMIT 25 offset 0) subquery_for_count", + "q1": "SELECT Count(*) FROM (SELECT 1 AS one FROM gh WHERE gh.group_id = 2578 AND gh.action = 2 LIMIT 25 offset 0) AS subquery_for_count", + }, + ] + rules = RuleGeneratorV2.recommend_simple_rules(examples) + assert _norm_sql(rules[0]["pattern"]) == _norm_sql( + "SELECT COUNT(*) FROM (SELECT 1 AS one FROM WHERE .group_id = 2578 AND .action = 2 ORDER BY .created_at DESC LIMIT 25 OFFSET 0) AS subquery_for_count" + ) + assert _norm_sql(rules[0]["rewrite"]) == _norm_sql( + "SELECT COUNT(*) FROM (SELECT 1 AS one FROM WHERE .group_id = 2578 AND .action = 2 LIMIT 25 OFFSET 0) AS subquery_for_count" + ) + + +def test_recommend_simple_rules_3(): + examples = [ + {"q0": "SELECT CAST(create_at as DATE)", "q1": "SELECT create_at"}, + {"q0": "SELECT CAST(create_at1 as DATE)", "q1": "SELECT create_at1"}, + {"q0": "SELECT STRPOS(LOWER(text), 'iphone') > 0", "q1": "SELECT ILIKE(text, '%iphone%')"}, + {"q0": "SELECT STRPOS(LOWER(text1), 'iphone') > 0", "q1": "SELECT ILIKE(text1, '%iphone%')"}, + {"q0": "SELECT STRPOS(LOWER(text), 'iphone1') > 0", "q1": "SELECT ILIKE(text, '%iphone1%')"}, + ] + rules = RuleGeneratorV2.recommend_simple_rules(examples) + assert _norm_sql(rules[0]["pattern"]) == _norm_sql("SELECT CAST( AS DATE)") + assert _norm_sql(rules[0]["rewrite"]) == _norm_sql("SELECT ") + assert _norm_sql(rules[1]["pattern"]) == _norm_sql("SELECT STRPOS(LOWER(text), '') > 0") + assert _norm_sql(rules[1]["rewrite"]) == _norm_sql("SELECT text ILIKE '%%'") + + +def test_recommend_simple_rules_4(): + examples = [ + { + "q0": "SELECT e1.name, e1.age, e2.salary FROM employee e1, employee e2 WHERE e1.id = e2.id AND e1.age > 17 AND e2.salary > 35000", + "q1": "SELECT e1.name, e1.age, e1.salary FROM employee e1 WHERE e1.age > 17 AND e1.salary > 35000", + }, + { + "q0": "SELECT e1.name, e1.ages, e2.salary FROM employee e1, employee e2 WHERE e1.id = e2.id AND e1.ages > 17 AND e2.salary > 35000", + "q1": "SELECT e1.name, e1.ages, e1.salary FROM employee e1 WHERE e1.ages > 17 AND e1.salary > 35000", + }, + { + "q0": "SELECT * FROM t WHERE CAST(created_at AS DATE) = TIMESTAMP '2016-10-01 00:00:00.000'", + "q1": "SELECT * FROM t WHERE created_at = TIMESTAMP '2016-10-01 00:00:00.000'", + }, + { + "q0": "SELECT s.ids from s WHERE s.x = 100 AND s.abc = 100", + "q1": "SELECT s.x from s WHERE s.x = 100", + }, + { + "q0": "SELECT student.ids from student WHERE student.id = 100 AND student.abc = 100", + "q1": "SELECT student.id from student WHERE student.id = 100", + }, + ] + rules = RuleGeneratorV2.recommend_simple_rules(examples) + assert _norm_sql(rules[0]["pattern"]) == _norm_sql( + "SELECT e1.name, e1., e2.salary FROM employee AS e1, employee AS e2 WHERE e1.id = e2.id AND e1. > 17 AND e2.salary > 35000" + ) + assert _norm_sql(rules[0]["rewrite"]) == _norm_sql( + "SELECT e1.name, e1., e1.salary FROM employee AS e1 WHERE e1. > 17 AND e1.salary > 35000" + ) + assert _norm_sql(rules[1]["pattern"]) == _norm_sql( + "SELECT * FROM WHERE CAST(created_at AS DATE) = TIMESTAMP('2016-10-01 00:00:00.000')" + ) + assert _norm_sql(rules[1]["rewrite"]) == _norm_sql( + "SELECT * FROM WHERE created_at = TIMESTAMP('2016-10-01 00:00:00.000')" + ) + assert _norm_sql(rules[2]["pattern"]) == _norm_sql( + "SELECT .ids FROM WHERE . = 100 AND .abc = 100" + ) + assert _norm_sql(rules[2]["rewrite"]) == _norm_sql( + "SELECT . FROM WHERE . = 100" + ) + + +def test_parse_validator_1(): + success1, _err1, _idx1 = RuleGeneratorV2.parse_validate_single("CAST( AS DATE)") + success2, _err2, _idx2 = RuleGeneratorV2.parse_validate_single("") + success3, _err3, _idx3 = RuleGeneratorV2.parse_validate("CAST( AS DATE)", "") + assert success1 is True + assert success2 is True + assert success3 is True + + +def test_parse_validator_2(): + success, errormessage, index = RuleGeneratorV2.parse_validate("CAST( AS DATE)", "") + assert success is False + assert index == 0 + assert "not in first rule" in errormessage + + +def test_parse_validator_3(): + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single("CAST( AS DATEE)") + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate("CAST( AS DATEE)", "") + assert success1 is False + assert index1 == 13 + assert "DATEE" in errormessage1 + assert success2 is False + assert index2 == 13 + assert "DATEE" in errormessage2 + + +def test_parse_validator_4(): + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single("CA NT( AS DATE)") + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate("CA NT( AS DATE)", "") + assert success1 is False + assert index1 == 3 + assert "NT" in errormessage1 + assert success2 is False + assert index2 == 3 + assert "NT" in errormessage2 + + +def test_parse_validator_5(): + pattern = """SELECT + FROM + WHERE > 10 + AND <= 10 + """ + rewrite = """SELECT + FROM + WHERE FALSE + """ + success1, _err1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, _err2, index2 = RuleGeneratorV2.parse_validate_single(rewrite) + success3, _err3, index3 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is True and index1 == 0 + assert success2 is True and index2 == 0 + assert success3 is True and index3 == 0 + + +def test_parse_validator_6(): + pattern = """FRUM + WHERE > 10 + AND <= 10 + """ + rewrite = """FROM + WHERE FALSE + """ + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 0 and "spelling" in errormessage1 + assert success2 is False and index2 == 0 and "spelling" in errormessage2 + + +def test_parse_validator_7(): + pattern = """WHURE > 10 + AND <= 10 + """ + rewrite = """WHERE FALSE""" + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 0 and "spelling" in errormessage1 + assert success2 is False and index2 == 0 and "spelling" in errormessage2 + + +def test_parse_validator_8(): + pattern = """SELUCT + FROM + WHERE >> 10 + AND <= 10 + """ + rewrite = """SELECT + FROM + WHERE FALSE + """ + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 0 and "spelling" in errormessage1 + assert success2 is False and index2 == 0 and "spelling" in errormessage2 + + +def test_parse_validator_9(): + pattern = """FRUM , EN END""" + rewrite = """FROM """ + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 0 and "spelling" in errormessage1 + assert success2 is False and index2 == 0 and "spelling" in errormessage2 + + +def test_parse_validator_10(): + pattern = """WHERE > 11 5 10 + AND <= 11 + """ + rewrite = """WHERE FALSE""" + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 16 and "5 10" in errormessage1 + assert success2 is False and index2 == 16 and "5 10" in errormessage2 + + +def test_parse_validator_13(): + pattern = """WHERE a <4x> > 11 + AND a <= 11 + """ + rewrite = """WHERE FALSE""" + success1, errormessage1, index1 = RuleGeneratorV2.parse_validate_single(pattern) + success2, errormessage2, index2 = RuleGeneratorV2.parse_validate(pattern, rewrite) + assert success1 is False and index1 == 8 and "<4x>" in errormessage1 + assert success2 is False and index2 == 8 and "<4x>" in errormessage2 + + +def test_parse_validator_14(): + success1, _err1, _idx1 = RuleGeneratorV2.parse_validate_single("CAST( AS TEXT)") + success2, _err2, _idx2 = RuleGeneratorV2.parse_validate_single("") + success3, _err3, _idx3 = RuleGeneratorV2.parse_validate("CAST( AS TEXT)", "") + assert success1 is True + assert success2 is True + assert success3 is True + + +def test_generate_rule_graph_0(): + q0 = "CAST(created_at AS DATE)" + q1 = "created_at" + root_rule = RuleGeneratorV2.generate_rule_graph(q0, q1) + assert isinstance(root_rule, RuleV2) + children = root_rule["children"] + assert len(children) == 1 + child_rule = children[0] + assert child_rule["pattern"] == "CAST( AS DATE)" + assert child_rule["rewrite"] == "" + +def test_spreadsheet_id_1(): + q0 = """SELECT users.id + FROM users INNER JOIN addresses + ON addresses.user_id = users.id + AND addresses.type = 'VerifiedAddress' +WHERE users.deleted_at IS NULL + AND users.id in (11144,10569,21519,783,15671,21726,17787,11665,19579,12226,1324,9413,5461,20981,12906) + AND addresses.state != 'manual_verification'""" + q1 = """SELECT addresses.user_id + FROM addresses + WHERE addresses.type = 'VerifiedAddress' + AND addresses.user_id in (11144,10569,21519,783,15671,21726,17787,11665,19579,12226,1324,9413,5461,20981,12906) + AND addresses.state != 'manual_verification'""" + + _assert_matches_expected( + q0, + q1, + "SELECT . FROM INNER JOIN ON . = . AND <> " + "WHERE . IS NULL AND . IN (<>) AND <>", + "SELECT . FROM WHERE <> AND . IN (<>) AND <>", + ) + +def test_generate_spreadsheet_id_3(): + q0 = "SELECT EMPNO FROM EMP WHERE EMPNO > 10 AND EMPNO <= 10" + q1 = "SELECT EMPNO FROM EMP WHERE FALSE" + _assert_matches_expected(q0, q1, " > AND <= ", "False") + + +def test_generate_spreadsheet_id_4(): + q0 = """SELECT entities.data FROM entities WHERE + entities._id IN (SELECT index_users_email._id FROM index_users_email WHERE index_users_email.key = 'test') + OR + entities._id in (SELECT index_users_profile_name._id FROM index_users_profile_name WHERE index_users_profile_name.key = 'test')""" + q1 = """SELECT entities.data FROM entities +WHERE entities._id IN + ( SELECT index_users_email._id + FROM index_users_email + WHERE index_users_email.key = 'test' + ) +UNION +SELECT entities.data FROM entities +WHERE entities._id in + ( SELECT index_users_profile_name._id + FROM index_users_profile_name + WHERE index_users_profile_name.key = 'test' + )""" + _assert_matches_rule(q0, q1, "spreadsheet_id_4") + + +def test_generate_spreadsheet_id_6(): + q0 = """SELECT * +FROM + table_name + WHERE + (table_name.title = 1 and table_name.grade = 2) + OR + (table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3) + OR + (table_name.prog = 1 and table_name.title =1 and table_name.debt = 3)""" + q1 = """SELECT * +FROM + table_name + WHERE + 1 = case + when table_name.title = 1 and table_name.grade = 2 then 1 + when table_name.title = 2 and table_name.debt = 2 and table_name.grade = 3 then 1 + when table_name.prog = 1 and table_name.title = 1 and table_name.debt = 3 then 1 + else 0 + end""" + _assert_matches_expected( + q0, + q1, + " OR OR ", + " = CASE WHEN THEN WHEN THEN WHEN THEN ELSE 0 END", + ) + + +def test_generate_spreadsheet_id_7(): + q0 = """select * from +a +left join b on a.id = b.cid +where +b.cl1 = 's1' +or +b.cl1 ='s2' +or +b.cl1 ='s3' """ + q1 = """select * from +a +left join b on a.id = b.cid +where +b.cl1 in ('s1','s2','s3')""" + _assert_matches_rule(q0, q1, "spreadsheet_id_7") + + +def test_generate_spreadsheet_id_9(): + q0 = """SELECT DISTINCT my_table.foo +FROM my_table +WHERE my_table.num = 1;""" + q1 = """SELECT my_table.foo +FROM my_table +WHERE my_table.num = 1 +GROUP BY my_table.foo;""" + _assert_matches_rule(q0, q1, "spreadsheet_id_9") + + +def test_generate_spreadsheet_id_10(): + q0 = """SELECT table1.wpis_id +FROM table1 +WHERE table1.etykieta_id IN ( + SELECT table2.tag_id + FROM table2 + WHERE table2.postac_id = 376476 + );""" + q1 = """SELECT table1.wpis_id +FROM table1 +INNER JOIN table2 on table2.tag_id = table1.etykieta_id +WHERE table2.postac_id = 376476""" + _assert_matches_rule(q0, q1, "spreadsheet_id_10") + + +def test_generate_spreadsheet_id_11(): + q0 = """SELECT historicoestatusrequisicion_id, requisicion_id, estatusrequisicion_id, + comentario, fecha_estatus, usuario_id + FROM historicoestatusrequisicion hist1 + WHERE requisicion_id IN + ( + SELECT requisicion_id FROM historicoestatusrequisicion hist2 + WHERE usuario_id = 27 AND estatusrequisicion_id = 1 + ) + ORDER BY requisicion_id, estatusrequisicion_id""" + q1 = """SELECT hist1.historicoestatusrequisicion_id, hist1.requisicion_id, hist1.estatusrequisicion_id, hist1.comentario, hist1.fecha_estatus, hist1.usuario_id + FROM historicoestatusrequisicion hist1 + JOIN historicoestatusrequisicion hist2 ON hist2.requisicion_id = hist1.requisicion_id + WHERE hist2.usuario_id = 27 AND hist2.estatusrequisicion_id = 1 + ORDER BY hist1.requisicion_id, hist1.estatusrequisicion_id""" + _assert_matches_rule(q0, q1, "spreadsheet_id_11") + + +def test_generate_spreadsheet_id_15(): + q0 = """SELECT * +FROM users u +WHERE u.id IN + (SELECT s1.user_id + FROM sessions s1 + WHERE s1.user_id <> 1234 + AND (s1.ip IN + (SELECT s2.ip + FROM sessions s2 + WHERE s2.user_id = 1234 + GROUP BY s2.ip) + OR s1.cookie_identifier IN + (SELECT s3.cookie_identifier + FROM sessions s3 + WHERE s3.user_id = 1234 + GROUP BY s3.cookie_identifier)) + GROUP BY s1.user_id)""" + q1 = """SELECT * +FROM users u +WHERE EXISTS ( + SELECT + NULL + FROM sessions s1 + WHERE s1.user_id <> 1234 + AND u.id = s1.user_id + AND EXISTS ( + SELECT + NULL + FROM sessions s2 + WHERE s2.user_id = 1234 + AND (s1.ip = s2.ip + OR s1.cookie_identifier = s2.cookie_identifier + ) + ) + )""" + _assert_matches_rule(q0, q1, "spreadsheet_id_15") + + +def test_generate_spreadsheet_id_18(): + q0 = """SELECT DISTINCT ON (t.playerId) t.gzpId, t.pubCode, t.playerId, + COALESCE (p.preferenceValue,'en'), + s.segmentId +FROM userPlayerIdMap t LEFT JOIN + userPreferences p + ON t.gzpId = p.gzpId LEFT JOIN + segment s + ON t.gzpId = s.gzpId +WHERE t.pubCode IN ('hyrmas','ayqioa','rj49as99') and + t.provider IN ('FCM','ONE_SIGNAL') and + s.segmentId IN (0,1,2,3,4,5,6) and + p.preferenceValue IN ('en','hi') +ORDER BY t.playerId desc;""" + q1 = """SELECT t.gzpId, t.pubCode, t.playerId, + COALESCE((SELECT p.preferenceValue + FROM userPreferences p + WHERE t.gzpId = p.gzpId AND + p.preferenceValue IN ('en', 'hi') + LIMIT 1 + ), 'en' + ), + (SELECT s.segmentId + FROM segment s + WHERE t.gzpId = s.gzpId AND + s.segmentId IN (0, 1, 2, 3, 4, 5, 6) + LIMIT 1 + ) +FROM userPlayerIdMap t +WHERE t.pubCode IN ('hyrmas', 'ayqioa', 'rj49as99') and + t.provider IN ('FCM', 'ONE_SIGNAL');""" + _assert_matches_expected( + q0, + q1, + "SELECT DISTINCT ON (.) <>, COALESCE(., ''), FROM LEFT JOIN ON <> LEFT JOIN ON <> WHERE <> AND <> AND <> ORDER BY DESC", + "SELECT <>, COALESCE((SELECT . FROM WHERE <> AND <> LIMIT ), ''), (SELECT FROM WHERE <> AND <> LIMIT ) FROM WHERE <>", + ) + + +def test_generate_spreadsheet_id_20(): + q0 = "SELECT * FROM (SELECT * FROM (SELECT NULL FROM EMP) WHERE N IS NULL) WHERE N IS NULL" + q1 = "SELECT * FROM (SELECT NULL FROM EMP) WHERE N IS NULL" + _assert_matches_rule(q0, q1, "spreadsheet_id_20") + + +def test_generate_spreadsheet_id_21(): + q0 = "SELECT * FROM (SELECT * FROM EMP AS t WHERE t.N IS NULL) AS t0 WHERE t0.N IS NULL" + q1 = "SELECT * FROM EMP AS t WHERE t.N IS NULL" + _assert_matches_expected( + q0, + q1, + "FROM (SELECT <> FROM WHERE <>) AS t0 WHERE t0. IS NULL", + "FROM WHERE <>", + ) diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index a25ce79..4f4ccfd 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -30,6 +30,7 @@ UnaryOperatorNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -87,7 +88,15 @@ def _assert_varnodes_declared(result: RuleParseResult) -> None: def _assert_no_internal_tokens(result: RuleParseResult) -> None: """No EV00x / SV00x tokens should survive in identifier-bearing AST fields.""" - internal_tokens = set(result.mapping.values()) + def _check_alias(label: str, field_name: str, value) -> None: + if isinstance(value, str): + assert not _TOKEN_RE.match(value), ( + f"{label} AST has raw internal token {value!r} as {field_name}" + ) + elif isinstance(value, ElementVariableNode): + assert not _TOKEN_RE.match(value.name), ( + f"{label} AST has raw internal token {value.name!r} inside ElementVariableNode at {field_name}" + ) for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: for n in _walk(tree): @@ -95,34 +104,22 @@ def _assert_no_internal_tokens(result: RuleParseResult) -> None: assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as ColumnNode.name" ) - if isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as ColumnNode.alias" - ) - if n.parent_alias in internal_tokens: - assert not _TOKEN_RE.match(n.parent_alias), ( - f"{tree_label} AST has raw internal token {n.parent_alias!r} " - f"as ColumnNode.parent_alias" - ) + _check_alias(tree_label, "ColumnNode.alias", n.alias) + _check_alias(tree_label, "ColumnNode.parent_alias", n.parent_alias) if isinstance(n, TableNode) and isinstance(n.name, str): assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as TableNode.name" ) - if isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as TableNode.alias" - ) + _check_alias(tree_label, "TableNode.alias", n.alias) if isinstance(n, SubqueryNode) and isinstance(n.alias, str): assert not _TOKEN_RE.match(n.alias), ( f"{tree_label} AST has raw internal token {n.alias!r} as SubqueryNode.alias" ) - if isinstance(n, FunctionNode) and isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as FunctionNode.alias" - ) + if isinstance(n, FunctionNode): + _check_alias(tree_label, "FunctionNode.alias", n.alias) # ═══════════════════════════════════════════════════════════════════════════════ @@ -364,14 +361,15 @@ def test_parse_ast_strpos_ilike_rule(): assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" assert isinstance(list(lower.children)[0], ElementVariableNode) assert list(lower.children)[0].name == "x" - assert isinstance(strpos_args[1], LiteralNode) + assert isinstance(strpos_args[1], VariableLiteralNode) + assert strpos_args[1].name == "s" # Rewrite: ILIKE rew = result.rewrite_ast assert isinstance(rew, FunctionNode) and rew.name.lower() == "ilike" ilike_args = list(rew.children) assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" - assert isinstance(ilike_args[1], LiteralNode) - assert ilike_args[1].value == "%s%" + assert isinstance(ilike_args[1], VariableLiteralNode) + assert ilike_args[1].name == "s" and ilike_args[1].prefix == "%" and ilike_args[1].suffix == "%" def test_substitute_placeholders_limit_offset_string_tokens(): @@ -382,8 +380,8 @@ def test_substitute_placeholders_limit_offset_string_tokens(): off = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] OffsetNode("EV002"), {"EV002": "y"} ) - assert isinstance(lim, LimitNode) and lim.limit == "x" - assert isinstance(off, OffsetNode) and off.offset == "y" + assert isinstance(lim, LimitNode) and isinstance(lim.limit, ElementVariableNode) and lim.limit.name == "x" + assert isinstance(off, OffsetNode) and isinstance(off.offset, ElementVariableNode) and off.offset.name == "y" def test_parse_substitutes_alias_fields(): @@ -470,13 +468,16 @@ def test_parse_where_scope_strips_select_and_from(): # ═══════════════════════════════════════════════════════════════════════════════ def test_parse_ast_from_scope(): + # After parser update: EV tokens in TABLE position are promoted to ElementVariableNode. + # The concrete alias "li" is not a variable, so it's dropped from the ElementVariableNode. + # The element variable "t" captures the whole table reference during matching. result = RuleParserV2.parse("FROM li", "FROM li") assert result.mapping == {"t": "EV001"} assert isinstance(result.pattern_ast, QueryNode) frm = next(c for c in result.pattern_ast.children if c.type == NodeType.FROM) assert isinstance(frm, FromNode) tab = list(frm.children)[0] - assert isinstance(tab, TableNode) and tab.name == "t" and tab.alias == "li" + assert isinstance(tab, ElementVariableNode) and tab.name == "t" def test_parse_from_scope_strips_select(): @@ -536,8 +537,8 @@ def test_parse_self_join_rule(): ) _assert_varnodes_declared(result) _assert_no_internal_tokens(result) - assert len(_find_all(result.pattern_ast, TableNode)) >= 2 - assert len(_find_all(result.rewrite_ast, TableNode)) >= 1 + assert len(_find_all(result.pattern_ast, ElementVariableNode)) >= 2 + assert len(_find_all(result.rewrite_ast, ElementVariableNode)) >= 1 pat_svs = [n for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)] assert len(pat_svs) >= 2 # s1 and p1 @@ -632,16 +633,19 @@ def test_parse_set_variable_in_select_and_where(): # ═══════════════════════════════════════════════════════════════════════════════ def test_qualified_column_both_parts_substituted(): - """. — both parent_alias and name should become external names.""" + """. — both parent_alias and name should become external names (ElementVariableNode).""" result = RuleParserV2.parse(". = 1", ". = 1") _assert_varnodes_declared(result) _assert_no_internal_tokens(result) - cols = _find_all(result.pattern_ast, ColumnNode) - qualified = [c for c in cols if c.parent_alias is not None] + # When both parts are variables, _substitute_placeholders returns ElementVariableNode + # with parent_alias stored as ElementVariableNode (widened field type). + evars = _find_all(result.pattern_ast, ElementVariableNode) + qualified = [e for e in evars if e.parent_alias is not None] assert len(qualified) >= 1 - for c in qualified: - assert c.parent_alias in result.mapping - assert c.name in result.mapping + for e in qualified: + assert isinstance(e.parent_alias, ElementVariableNode) + assert e.parent_alias.name in result.mapping + assert e.name in result.mapping def test_qualified_column_only_parent_alias_is_var():