1818import json
1919import re
2020import time
21+ from collections .abc import Callable
2122from dataclasses import dataclass , field
2223from datetime import datetime
2324from enum import Enum
4445from app .resilience import CircuitBreaker , CircuitBreakerConfig , compute_backoff_seconds
4546from db .connection import DatabaseManager
4647from db .executor import QueryResult
48+ from db .registry import get_database_registry
4749from db .schema import SchemaInfo , SchemaIntrospector
4850
4951logger = get_logger (__name__ )
@@ -103,6 +105,16 @@ def to_dict(self) -> dict[str, Any]:
103105 }
104106
105107
108+ @dataclass (frozen = True )
109+ class DatabaseContext :
110+ """Resolved database resources for a database id."""
111+
112+ database_id : str
113+ engine : Any
114+ dialect : str
115+ session_provider : Callable [[], Any ]
116+
117+
106118@dataclass
107119class ReasoningStep :
108120 """
@@ -927,6 +939,7 @@ async def validate_sql(
927939 async def _get_schema_context (self , database_id : str ) -> SchemaContext :
928940 """Get or create schema context for database."""
929941 start_time = time .perf_counter ()
942+ db_context = await self ._resolve_database_context (database_id )
930943 # Check cache
931944 if database_id in self ._schema_cache :
932945 return self ._schema_cache [database_id ]
@@ -940,7 +953,7 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
940953 )
941954 serialized = cached_payload .get ("serialized_schema" )
942955 if not serialized :
943- introspector = SchemaIntrospector (self . _db_manager .engine )
956+ introspector = SchemaIntrospector (db_context .engine )
944957 serialized = introspector .serialize_for_prompt (schema_info )
945958 table_names = cached_payload .get ("table_names" ) or [
946959 table .name for table in schema_info .tables
@@ -962,7 +975,7 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
962975 return context
963976
964977 # Create introspector and get schema
965- introspector = SchemaIntrospector (self . _db_manager .engine )
978+ introspector = SchemaIntrospector (db_context .engine )
966979 with self ._tracing .create_span (
967980 "schema.introspect" ,
968981 attributes = {"db.database_id" : database_id },
@@ -1004,6 +1017,35 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
10041017 self ._schema_cache [database_id ] = context
10051018 return context
10061019
1020+ async def _resolve_database_context (
1021+ self ,
1022+ database_id : str | None ,
1023+ ) -> DatabaseContext :
1024+ resolved_id = database_id or "default"
1025+ settings = get_settings ()
1026+
1027+ if settings .multi_database .enabled :
1028+ registry = await get_database_registry ()
1029+ registered = registry .get_database (resolved_id )
1030+ dialect = (
1031+ registered .config .dialect .value
1032+ if registered .config .dialect is not None
1033+ else registered .engine .dialect .name
1034+ )
1035+ return DatabaseContext (
1036+ database_id = resolved_id ,
1037+ engine = registered .engine ,
1038+ dialect = dialect ,
1039+ session_provider = lambda : registry .session (resolved_id ),
1040+ )
1041+
1042+ return DatabaseContext (
1043+ database_id = resolved_id ,
1044+ engine = self ._db_manager .engine ,
1045+ dialect = self ._db_manager .dialect ,
1046+ session_provider = self ._db_manager .session ,
1047+ )
1048+
10071049 async def _get_few_shot_examples (
10081050 self ,
10091051 natural_query : str ,
@@ -1295,20 +1337,21 @@ async def execute_sql(
12951337 """
12961338 from db .executor import SafeQueryExecutor
12971339
1298- async with self ._db_manager .session () as session :
1340+ db_context = await self ._resolve_database_context (database_id )
1341+ async with db_context .session_provider () as session :
12991342 executor = SafeQueryExecutor (
13001343 session = session ,
13011344 allow_mutations = False ,
13021345 max_rows = max_rows ,
1303- database_id = database_id or "default" ,
1304- dialect = self . _db_manager .dialect ,
1346+ database_id = db_context . database_id ,
1347+ dialect = db_context .dialect ,
13051348 )
13061349
13071350 result : QueryResult = await executor .execute (sql )
13081351
13091352 logger .info (
13101353 "text2sql_execute_sql" ,
1311- database_id = database_id or "default" ,
1354+ database_id = db_context . database_id ,
13121355 row_count = result .row_count ,
13131356 )
13141357
0 commit comments