Skip to content

Commit f4f5291

Browse files
committed
test: cover multi-db schema registration
1 parent e578b46 commit f4f5291

3 files changed

Lines changed: 157 additions & 0 deletions

File tree

tests/unit/test_registry.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_database_registry,
2121
reset_database_registry,
2222
)
23+
from db.schema import SchemaInfo
2324

2425

2526
class TestDatabaseConfig:
@@ -280,6 +281,29 @@ async def test_get_adapter(self, registry: DatabaseRegistry) -> None:
280281

281282
await registry.close_all()
282283

284+
@pytest.mark.asyncio
285+
async def test_update_schema_stores_metadata(
286+
self, registry: DatabaseRegistry
287+
) -> None:
288+
"""Test storing schema metadata for a registered database."""
289+
config = DatabaseConfig(
290+
database_id="schema_test",
291+
connection_string="sqlite+aiosqlite:///:memory:",
292+
)
293+
registered = await registry.register_database(config, test_connection=False)
294+
295+
schema = SchemaInfo(
296+
database_id="schema_test",
297+
dialect="sqlite",
298+
tables=[],
299+
)
300+
registry.update_schema("schema_test", schema)
301+
302+
assert registered.schema is schema
303+
assert registered.schema_updated_at == schema.last_updated
304+
305+
await registry.close_all()
306+
283307
@pytest.mark.asyncio
284308
async def test_health_check(self, registry: DatabaseRegistry) -> None:
285309
"""Test health check on database."""

tests/unit/test_routes_schema.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
Unit tests for schema routes (Issue #31: Multi-DB registry wiring).
3+
"""
4+
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
from fastapi import FastAPI
9+
from fastapi.testclient import TestClient
10+
11+
from app.routes import router
12+
from db.schema import SchemaInfo
13+
14+
15+
@pytest.fixture
16+
def app() -> FastAPI:
17+
"""Create test FastAPI app."""
18+
from app.error_handlers import setup_exception_handlers
19+
from app.security import limiter
20+
21+
test_app = FastAPI()
22+
test_app.state.limiter = limiter
23+
setup_exception_handlers(test_app)
24+
test_app.include_router(router)
25+
26+
return test_app
27+
28+
29+
@pytest.fixture
30+
def client(app: FastAPI, auth_headers: dict[str, str]) -> TestClient:
31+
"""Create test client."""
32+
client = TestClient(app)
33+
client.headers.update(auth_headers)
34+
return client
35+
36+
37+
class TestSchemaRoutes:
38+
"""Tests for schema registration endpoints."""
39+
40+
def test_register_schema_registers_and_caches(
41+
self,
42+
client: TestClient,
43+
) -> None:
44+
"""Ensure schema registration wires registry and cache."""
45+
schema_info = SchemaInfo(
46+
database_id="analytics",
47+
dialect="postgresql",
48+
tables=[],
49+
)
50+
51+
mock_introspector = MagicMock()
52+
mock_introspector.get_schema = AsyncMock(return_value=schema_info)
53+
mock_introspector.serialize_for_prompt.return_value = "schema"
54+
55+
registry = MagicMock()
56+
registry.database_count = 0
57+
registry.register_database = AsyncMock()
58+
registered = MagicMock()
59+
registered.engine = MagicMock()
60+
registry.register_database.return_value = registered
61+
62+
settings = MagicMock()
63+
settings.multi_database.enabled = True
64+
settings.multi_database.max_databases = 50
65+
settings.multi_database.default_pool_size = 5
66+
settings.multi_database.default_max_overflow = 10
67+
settings.multi_database.default_pool_timeout = 30
68+
settings.multi_database.allow_mutations = False
69+
settings.multi_database.require_connection_test = True
70+
settings.cache.enabled = True
71+
72+
with (
73+
patch("app.routes.get_settings", return_value=settings),
74+
patch("db.registry.get_database_registry", return_value=registry),
75+
patch("db.schema.SchemaIntrospector", return_value=mock_introspector),
76+
patch("app.cache.cache_schema", new_callable=AsyncMock) as cache_schema,
77+
):
78+
response = client.post(
79+
"/api/v1/schema/register",
80+
json={
81+
"database_id": "analytics",
82+
"connection_string": "sqlite:///test.db",
83+
"dialect": "postgresql",
84+
},
85+
)
86+
87+
assert response.status_code == 200
88+
payload = response.json()
89+
assert payload["status"] == "registered"
90+
assert payload["database_id"] == "analytics"
91+
assert payload["dialect"] == "postgresql"
92+
93+
registry.register_database.assert_awaited_once()
94+
args, kwargs = registry.register_database.call_args
95+
config = args[0]
96+
assert config.database_id == "analytics"
97+
assert config.connection_string == "sqlite:///test.db"
98+
assert config.pool_size == 5
99+
assert kwargs["test_connection"] is True
100+
101+
mock_introspector.get_schema.assert_awaited_once_with("analytics")
102+
cache_schema.assert_awaited_once()

tests/unit/test_text2sql_engine.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def mock_settings(self) -> MagicMock:
433433
)
434434
settings.cache = MagicMock(enabled=False)
435435
settings.huggingface = MagicMock(model_name="test-model")
436+
settings.multi_database = MagicMock(enabled=False)
436437
return settings
437438

438439
def test_initialization(
@@ -633,6 +634,35 @@ async def test_generate_with_retry_uses_fallback_when_circuit_open(
633634
assert result.metadata.get("fallback") is True
634635
assert any("circuit" in w for w in warnings)
635636

637+
@pytest.mark.asyncio
638+
async def test_resolve_database_context_uses_registry(
639+
self,
640+
mock_db_manager: MagicMock,
641+
mock_settings: MagicMock,
642+
) -> None:
643+
"""Test registry-based context resolution when enabled."""
644+
mock_settings.multi_database.enabled = True
645+
646+
registered = MagicMock()
647+
registered.engine = MagicMock()
648+
registered.config.dialect = MagicMock()
649+
registered.config.dialect.value = "postgresql"
650+
651+
registry = MagicMock()
652+
registry.get_database.return_value = registered
653+
registry.session.return_value = "session_ctx"
654+
655+
with (
656+
patch("app.text2sql_engine.get_settings", return_value=mock_settings),
657+
patch("app.text2sql_engine.get_database_registry", return_value=registry),
658+
):
659+
engine = Text2SQLEngine(mock_db_manager)
660+
context = await engine._resolve_database_context("analytics")
661+
662+
assert context.database_id == "analytics"
663+
assert context.dialect == "postgresql"
664+
assert context.session_provider() == "session_ctx"
665+
636666

637667
# =============================================================================
638668
# Test Global Engine Management
@@ -681,6 +711,7 @@ async def test_get_text2sql_engine(self) -> None:
681711
),
682712
cache=MagicMock(enabled=False),
683713
huggingface=MagicMock(model_name="test-model"),
714+
multi_database=MagicMock(enabled=False),
684715
)
685716

686717
engine = await get_text2sql_engine()

0 commit comments

Comments
 (0)