|
3 | 3 | Expression handlers for converting SQL expressions to MongoDB query format |
4 | 4 | """ |
5 | 5 | import logging |
| 6 | +import re |
6 | 7 | import time |
7 | 8 | from abc import ABC, abstractmethod |
8 | 9 | from dataclasses import dataclass, field |
@@ -98,6 +99,27 @@ def has_children(ctx: Any) -> bool: |
98 | 99 | """Check if context has children""" |
99 | 100 | return hasattr(ctx, "children") and bool(ctx.children) |
100 | 101 |
|
| 102 | + @staticmethod |
| 103 | + def normalize_field_path(path: str) -> str: |
| 104 | + """Normalize jmspath/bracket notation to MongoDB dot notation. |
| 105 | +
|
| 106 | + Examples: |
| 107 | + items[0] -> items.0 |
| 108 | + items[1].name -> items.1.name |
| 109 | + arr['key'] or arr["key"] -> arr.key |
| 110 | + """ |
| 111 | + if not isinstance(path, str): |
| 112 | + return path |
| 113 | + |
| 114 | + s = path.strip() |
| 115 | + # Convert quoted bracket identifiers ["name"] or ['name'] -> .name |
| 116 | + s = re.sub(r"\[\s*['\"]([^'\"]+)['\"]\s*\]", r".\1", s) |
| 117 | + # Convert numeric bracket indexes [0] -> .0 |
| 118 | + s = re.sub(r"\[\s*(\d+)\s*\]", r".\1", s) |
| 119 | + # Collapse multiple dots and strip leading/trailing dots |
| 120 | + s = re.sub(r"\.{2,}", ".", s).strip(".") |
| 121 | + return s |
| 122 | + |
101 | 123 |
|
102 | 124 | class LoggingMixin: |
103 | 125 | """Mixin providing structured logging functionality""" |
@@ -358,21 +380,23 @@ def _extract_field_name(self, ctx: Any) -> str: |
358 | 380 | sql_keywords = ["IN(", "LIKE", "BETWEEN", "ISNULL", "ISNOTNULL"] |
359 | 381 | for keyword in sql_keywords: |
360 | 382 | if keyword in text: |
361 | | - return text.split(keyword, 1)[0].strip() |
| 383 | + candidate = text.split(keyword, 1)[0].strip() |
| 384 | + return self.normalize_field_path(candidate) |
362 | 385 |
|
363 | 386 | # Try operator-based splitting |
364 | 387 | operator = self._find_operator_in_text(text, COMPARISON_OPERATORS) |
365 | 388 | if operator: |
366 | 389 | parts = self._split_by_operator(text, operator) |
367 | 390 | if parts: |
368 | | - return parts[0].strip("'\"()") |
| 391 | + candidate = parts[0].strip("'\"()") |
| 392 | + return self.normalize_field_path(candidate) |
369 | 393 |
|
370 | 394 | # Fallback to children parsing |
371 | 395 | if self.has_children(ctx): |
372 | 396 | for child in ctx.children: |
373 | 397 | child_text = self.get_context_text(child) |
374 | 398 | if child_text not in COMPARISON_OPERATORS and not child_text.startswith(("'", '"')): |
375 | | - return child_text |
| 399 | + return self.normalize_field_path(child_text) |
376 | 400 |
|
377 | 401 | return "unknown_field" |
378 | 402 | except Exception as e: |
@@ -873,7 +897,7 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]: |
873 | 897 | # Visitor Handler Classes for AST Processing |
874 | 898 |
|
875 | 899 |
|
876 | | -class SelectHandler(BaseHandler): |
| 900 | +class SelectHandler(BaseHandler, ContextUtilsMixin): |
877 | 901 | """Handles SELECT statement parsing""" |
878 | 902 |
|
879 | 903 | def can_handle(self, ctx: Any) -> bool: |
@@ -903,6 +927,9 @@ def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]: |
903 | 927 | # OR children[1] might be just symbolPrimitive (without AS) |
904 | 928 |
|
905 | 929 | field_name = item.children[0].getText() |
| 930 | + # Normalize bracket notation (jmspath) to Mongo dot notation |
| 931 | + field_name = self.normalize_field_path(field_name) |
| 932 | + |
906 | 933 | alias = None |
907 | 934 |
|
908 | 935 | if len(item.children) >= 2: |
|
0 commit comments