Skip to content

Commit e578b46

Browse files
committed
feat(db): wire schema registry and multi-db routing
1 parent ba42d50 commit e578b46

4 files changed

Lines changed: 199 additions & 18 deletions

File tree

app/routes.py

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,17 @@ async def get_schema(
473473
try:
474474
from db.schema import SchemaIntrospector
475475

476-
db_manager = await get_database()
477-
introspector = SchemaIntrospector(db_manager.engine)
476+
settings = get_settings()
477+
if settings.multi_database.enabled:
478+
from db.registry import get_database_registry
479+
480+
registry = await get_database_registry()
481+
introspector = SchemaIntrospector(
482+
engine_provider=registry.get_engine,
483+
)
484+
else:
485+
db_manager = await get_database()
486+
introspector = SchemaIntrospector(db_manager.engine)
478487
schema = await introspector.get_schema(database_id)
479488

480489
return SchemaResponse(
@@ -502,7 +511,7 @@ async def register_schema(
502511
request: Request,
503512
response: Response,
504513
schema_request: SchemaRequest,
505-
) -> dict[str, str]:
514+
) -> dict[str, Any]:
506515
"""
507516
Register a new database schema.
508517
@@ -530,8 +539,96 @@ async def register_schema(
530539
dialect=schema_request.dialect,
531540
)
532541

533-
# TODO: Implement actual registration
534-
return {"status": "registered", "database_id": schema_request.database_id}
542+
settings = get_settings()
543+
544+
try:
545+
from app.cache import cache_schema
546+
from db.schema import SchemaIntrospector
547+
548+
registry = None
549+
550+
if settings.multi_database.enabled:
551+
from db.dialects import SQLDialect
552+
from db.registry import DatabaseConfig, get_database_registry
553+
554+
registry = await get_database_registry()
555+
if registry.database_count >= settings.multi_database.max_databases:
556+
raise ValidationException(
557+
message=(
558+
"Maximum database limit reached "
559+
f"({settings.multi_database.max_databases})"
560+
),
561+
validation_errors=[
562+
{"field": "database_id", "error": "Registry is full"}
563+
],
564+
)
565+
566+
dialect: SQLDialect | None = None
567+
if schema_request.dialect:
568+
try:
569+
dialect = SQLDialect(schema_request.dialect.lower())
570+
except ValueError:
571+
raise ValidationException(
572+
message=f"Invalid dialect: {schema_request.dialect}",
573+
validation_errors=[
574+
{
575+
"field": "dialect",
576+
"error": (
577+
"Supported: "
578+
f"{SQLDialect.supported_dialects()}"
579+
),
580+
}
581+
],
582+
) from None
583+
584+
config = DatabaseConfig(
585+
database_id=schema_request.database_id,
586+
connection_string=schema_request.connection_string,
587+
dialect=dialect,
588+
pool_size=settings.multi_database.default_pool_size,
589+
max_overflow=settings.multi_database.default_max_overflow,
590+
pool_timeout=settings.multi_database.default_pool_timeout,
591+
read_only=not settings.multi_database.allow_mutations,
592+
)
593+
594+
registered = await registry.register_database(
595+
config,
596+
test_connection=settings.multi_database.require_connection_test,
597+
)
598+
introspector = SchemaIntrospector(registered.engine)
599+
else:
600+
db_manager = await get_database()
601+
introspector = SchemaIntrospector(db_manager.engine)
602+
603+
schema = await introspector.get_schema(schema_request.database_id)
604+
serialized = introspector.serialize_for_prompt(schema)
605+
if registry is not None:
606+
registry.update_schema(schema_request.database_id, schema)
607+
608+
if settings.cache.enabled:
609+
await cache_schema(
610+
schema_request.database_id,
611+
{
612+
"schema_info": schema.to_dict(),
613+
"serialized_schema": serialized,
614+
"table_names": [table.name for table in schema.tables],
615+
"dialect": schema.dialect,
616+
},
617+
)
618+
619+
return {
620+
"status": "registered",
621+
"database_id": schema_request.database_id,
622+
"table_count": len(schema.tables),
623+
"dialect": schema.dialect,
624+
}
625+
except Exception as e:
626+
logger.error(
627+
"register_schema_error",
628+
database_id=schema_request.database_id,
629+
error=str(e),
630+
)
631+
raise
535632

536633

537634
# =============================================================================

app/text2sql_engine.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import re
2020
import time
21+
from collections.abc import Callable
2122
from dataclasses import dataclass, field
2223
from datetime import datetime
2324
from enum import Enum
@@ -44,6 +45,7 @@
4445
from app.resilience import CircuitBreaker, CircuitBreakerConfig, compute_backoff_seconds
4546
from db.connection import DatabaseManager
4647
from db.executor import QueryResult
48+
from db.registry import get_database_registry
4749
from db.schema import SchemaInfo, SchemaIntrospector
4850

4951
logger = get_logger(__name__)
@@ -103,6 +105,16 @@ def to_dict(self) -> dict[str, Any]:
103105
}
104106

105107

108+
@dataclass(frozen=True)
109+
class DatabaseContext:
110+
"""Resolved database resources for a database id."""
111+
112+
database_id: str
113+
engine: Any
114+
dialect: str
115+
session_provider: Callable[[], Any]
116+
117+
106118
@dataclass
107119
class ReasoningStep:
108120
"""
@@ -927,6 +939,7 @@ async def validate_sql(
927939
async def _get_schema_context(self, database_id: str) -> SchemaContext:
928940
"""Get or create schema context for database."""
929941
start_time = time.perf_counter()
942+
db_context = await self._resolve_database_context(database_id)
930943
# Check cache
931944
if database_id in self._schema_cache:
932945
return self._schema_cache[database_id]
@@ -940,7 +953,7 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
940953
)
941954
serialized = cached_payload.get("serialized_schema")
942955
if not serialized:
943-
introspector = SchemaIntrospector(self._db_manager.engine)
956+
introspector = SchemaIntrospector(db_context.engine)
944957
serialized = introspector.serialize_for_prompt(schema_info)
945958
table_names = cached_payload.get("table_names") or [
946959
table.name for table in schema_info.tables
@@ -962,7 +975,7 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
962975
return context
963976

964977
# Create introspector and get schema
965-
introspector = SchemaIntrospector(self._db_manager.engine)
978+
introspector = SchemaIntrospector(db_context.engine)
966979
with self._tracing.create_span(
967980
"schema.introspect",
968981
attributes={"db.database_id": database_id},
@@ -1004,6 +1017,35 @@ async def _get_schema_context(self, database_id: str) -> SchemaContext:
10041017
self._schema_cache[database_id] = context
10051018
return context
10061019

1020+
async def _resolve_database_context(
1021+
self,
1022+
database_id: str | None,
1023+
) -> DatabaseContext:
1024+
resolved_id = database_id or "default"
1025+
settings = get_settings()
1026+
1027+
if settings.multi_database.enabled:
1028+
registry = await get_database_registry()
1029+
registered = registry.get_database(resolved_id)
1030+
dialect = (
1031+
registered.config.dialect.value
1032+
if registered.config.dialect is not None
1033+
else registered.engine.dialect.name
1034+
)
1035+
return DatabaseContext(
1036+
database_id=resolved_id,
1037+
engine=registered.engine,
1038+
dialect=dialect,
1039+
session_provider=lambda: registry.session(resolved_id),
1040+
)
1041+
1042+
return DatabaseContext(
1043+
database_id=resolved_id,
1044+
engine=self._db_manager.engine,
1045+
dialect=self._db_manager.dialect,
1046+
session_provider=self._db_manager.session,
1047+
)
1048+
10071049
async def _get_few_shot_examples(
10081050
self,
10091051
natural_query: str,
@@ -1295,20 +1337,21 @@ async def execute_sql(
12951337
"""
12961338
from db.executor import SafeQueryExecutor
12971339

1298-
async with self._db_manager.session() as session:
1340+
db_context = await self._resolve_database_context(database_id)
1341+
async with db_context.session_provider() as session:
12991342
executor = SafeQueryExecutor(
13001343
session=session,
13011344
allow_mutations=False,
13021345
max_rows=max_rows,
1303-
database_id=database_id or "default",
1304-
dialect=self._db_manager.dialect,
1346+
database_id=db_context.database_id,
1347+
dialect=db_context.dialect,
13051348
)
13061349

13071350
result: QueryResult = await executor.execute(sql)
13081351

13091352
logger.info(
13101353
"text2sql_execute_sql",
1311-
database_id=database_id or "default",
1354+
database_id=db_context.database_id,
13121355
row_count=result.row_count,
13131356
)
13141357

db/registry.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
convert_url_to_async,
3737
get_dialect_adapter,
3838
)
39+
from db.schema import SchemaInfo
3940

4041
logger = get_logger(__name__)
4142

@@ -160,14 +161,23 @@ class RegisteredDatabase:
160161
session_factory: async_sessionmaker[AsyncSession]
161162
adapter: DialectAdapter
162163
health: DatabaseHealth
164+
schema: SchemaInfo | None = None
165+
schema_updated_at: datetime | None = None
163166
registered_at: datetime = field(default_factory=datetime.now)
164167

165168
def to_dict(self) -> dict[str, Any]:
166169
"""Convert to dictionary."""
170+
schema_metadata = None
171+
if self.schema:
172+
schema_metadata = {
173+
"table_count": len(self.schema.tables),
174+
"last_updated": self.schema.last_updated.isoformat(),
175+
}
167176
return {
168177
**self.config.to_dict(),
169178
"registered_at": self.registered_at.isoformat(),
170179
"health": self.health.to_dict(),
180+
"schema": schema_metadata,
171181
}
172182

173183

@@ -400,6 +410,18 @@ def get_adapter(self, database_id: str) -> DialectAdapter:
400410
"""
401411
return self.get_database(database_id).adapter
402412

413+
def update_schema(self, database_id: str, schema: SchemaInfo) -> None:
414+
"""
415+
Store schema metadata for a registered database.
416+
417+
Args:
418+
database_id: Database identifier
419+
schema: Extracted schema information
420+
"""
421+
registered = self.get_database(database_id)
422+
registered.schema = schema
423+
registered.schema_updated_at = schema.last_updated
424+
403425
@asynccontextmanager
404426
async def session(self, database_id: str) -> AsyncGenerator[AsyncSession, None]:
405427
"""

db/schema.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
database schema information for use in SQL generation prompts.
66
"""
77

8+
from collections.abc import Callable
89
from dataclasses import dataclass, field
910
from datetime import datetime
1011
from typing import Any
@@ -181,14 +182,24 @@ class SchemaIntrospector:
181182
prompt_text = introspector.serialize_for_prompt(schema)
182183
"""
183184

184-
def __init__(self, engine: AsyncEngine) -> None:
185+
def __init__(
186+
self,
187+
engine: AsyncEngine | None = None,
188+
engine_provider: Callable[[str], AsyncEngine] | None = None,
189+
) -> None:
185190
"""
186191
Initialize schema introspector.
187192
188193
Args:
189194
engine: SQLAlchemy async engine
195+
engine_provider: Callable that resolves an engine per database ID
190196
"""
197+
if engine is None and engine_provider is None:
198+
raise ValueError(
199+
"SchemaIntrospector requires an engine or engine_provider."
200+
)
191201
self._engine = engine
202+
self._engine_provider = engine_provider
192203
self._cache: dict[str, SchemaInfo] = {}
193204

194205
async def get_schema(
@@ -214,9 +225,19 @@ async def get_schema(
214225
logger.info("introspecting_schema", database_id=database_id)
215226

216227
try:
217-
async with self._engine.connect() as conn:
228+
engine = (
229+
self._engine_provider(database_id)
230+
if self._engine_provider
231+
else self._engine
232+
)
233+
if engine is None:
234+
raise ValueError("SchemaIntrospector engine is not configured.")
235+
async with engine.connect() as conn:
218236
schema = await self._extract_schema(
219-
conn, database_id, include_row_counts
237+
conn,
238+
database_id,
239+
include_row_counts,
240+
engine.dialect.name,
220241
)
221242
self._cache[database_id] = schema
222243
return schema
@@ -237,6 +258,7 @@ async def _extract_schema(
237258
conn: AsyncConnection,
238259
database_id: str,
239260
include_row_counts: bool,
261+
dialect_name: str,
240262
) -> SchemaInfo:
241263
"""Extract schema from database connection."""
242264

@@ -319,12 +341,9 @@ def sync_inspect(connection: Any) -> dict[str, Any]:
319341
error=str(e),
320342
)
321343

322-
# Get dialect
323-
dialect = self._engine.dialect.name
324-
325344
return SchemaInfo(
326345
database_id=database_id,
327-
dialect=dialect,
346+
dialect=dialect_name,
328347
tables=tables,
329348
)
330349

0 commit comments

Comments
 (0)