Skip to content

Commit a5153d5

Browse files
committed
feat: asgi and search service
1 parent d4b1c1f commit a5153d5

9 files changed

Lines changed: 605 additions & 419 deletions

File tree

Lines changed: 12 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
import logging
2-
from typing import Any, Dict, List
32

43
from flask import make_response, request
54
from flask_restx import fields, Resource
65

76
from application.api.answer.routes.base import answer_ns
8-
from application.core.settings import settings
9-
from application.storage.db.repositories.agents import AgentsRepository
10-
from application.storage.db.session import db_readonly
11-
from application.vectorstore.vector_creator import VectorCreator
7+
from application.services.search_service import (
8+
InvalidAPIKey,
9+
SearchFailed,
10+
search,
11+
)
1212

1313
logger = logging.getLogger(__name__)
1414

1515

1616
@answer_ns.route("/api/search")
1717
class SearchResource(Resource):
18-
"""Fast search endpoint for retrieving relevant documents"""
18+
"""Fast search endpoint for retrieving relevant documents."""
1919

2020
search_model = answer_ns.model(
2121
"SearchModel",
@@ -32,135 +32,24 @@ class SearchResource(Resource):
3232
},
3333
)
3434

35-
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
36-
"""Get source IDs connected to the API key/agent."""
37-
with db_readonly() as conn:
38-
agent_data = AgentsRepository(conn).find_by_key(api_key)
39-
if not agent_data:
40-
return []
41-
42-
source_ids: List[str] = []
43-
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
44-
extra = agent_data.get("extra_source_ids") or []
45-
for src in extra:
46-
if src:
47-
source_ids.append(str(src))
48-
49-
if not source_ids:
50-
single = agent_data.get("source_id")
51-
if single:
52-
source_ids.append(str(single))
53-
54-
return source_ids
55-
56-
def _search_vectorstores(
57-
self, query: str, source_ids: List[str], chunks: int
58-
) -> List[Dict[str, Any]]:
59-
"""Search across vectorstores and return results"""
60-
if not source_ids:
61-
return []
62-
63-
results = []
64-
chunks_per_source = max(1, chunks // len(source_ids))
65-
seen_texts = set()
66-
67-
for source_id in source_ids:
68-
if not source_id or not source_id.strip():
69-
continue
70-
71-
try:
72-
docsearch = VectorCreator.create_vectorstore(
73-
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
74-
)
75-
docs = docsearch.search(query, k=chunks_per_source * 2)
76-
77-
for doc in docs:
78-
if len(results) >= chunks:
79-
break
80-
81-
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
82-
page_content = doc.page_content
83-
metadata = doc.metadata
84-
else:
85-
page_content = doc.get("text", doc.get("page_content", ""))
86-
metadata = doc.get("metadata", {})
87-
88-
# Skip duplicates
89-
text_hash = hash(page_content[:200])
90-
if text_hash in seen_texts:
91-
continue
92-
seen_texts.add(text_hash)
93-
94-
title = metadata.get(
95-
"title", metadata.get("post_title", "")
96-
)
97-
if not isinstance(title, str):
98-
title = str(title) if title else ""
99-
100-
# Clean up title
101-
if title:
102-
title = title.split("/")[-1]
103-
else:
104-
# Use filename or first part of content as title
105-
title = metadata.get("filename", page_content[:50] + "...")
106-
107-
source = metadata.get("source", source_id)
108-
109-
results.append({
110-
"text": page_content,
111-
"title": title,
112-
"source": source,
113-
})
114-
115-
if len(results) >= chunks:
116-
break
117-
118-
except Exception as e:
119-
logger.error(
120-
f"Error searching vectorstore {source_id}: {e}",
121-
exc_info=True,
122-
)
123-
continue
124-
125-
return results[:chunks]
126-
12735
@answer_ns.expect(search_model)
12836
@answer_ns.doc(description="Search for relevant documents based on query")
12937
def post(self):
130-
data = request.get_json()
38+
data = request.get_json() or {}
13139

13240
question = data.get("question")
13341
api_key = data.get("api_key")
13442
chunks = data.get("chunks", 5)
13543

13644
if not question:
13745
return make_response({"error": "question is required"}, 400)
138-
13946
if not api_key:
14047
return make_response({"error": "api_key is required"}, 400)
14148

142-
# Validate API key
143-
with db_readonly() as conn:
144-
agent = AgentsRepository(conn).find_by_key(api_key)
145-
if not agent:
146-
return make_response({"error": "Invalid API key"}, 401)
147-
14849
try:
149-
# Get sources connected to this API key
150-
source_ids = self._get_sources_from_api_key(api_key)
151-
152-
if not source_ids:
153-
return make_response([], 200)
154-
155-
# Perform search
156-
results = self._search_vectorstores(question, source_ids, chunks)
157-
158-
return make_response(results, 200)
159-
160-
except Exception as e:
161-
logger.error(
162-
f"/api/search - error: {str(e)}",
163-
extra={"error": str(e)},
164-
exc_info=True,
165-
)
50+
return make_response(search(api_key, question, chunks), 200)
51+
except InvalidAPIKey:
52+
return make_response({"error": "Invalid API key"}, 401)
53+
except SearchFailed:
54+
logger.exception("/api/search failed")
16655
return make_response({"error": "Search failed"}, 500)

application/asgi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from a2wsgi import WSGIMiddleware
2+
3+
from application.app import app as flask_app
4+
5+
asgi_app = WSGIMiddleware(flask_app)

application/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ docx2txt==0.9
1414
ddgs>=8.0.0
1515
fast-ebook
1616
elevenlabs==2.43.0
17-
Flask==3.1.3
17+
Flask==3.1.1
1818
faiss-cpu==1.13.2
1919
fastmcp==3.2.4
2020
flask-restx==1.3.2

application/services/__init__.py

Whitespace-only changes.
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Shared retrieval service used by the HTTP search route and the MCP tool.
2+
3+
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
4+
that callers translate into their own wire protocol (HTTP status codes,
5+
MCP error responses, etc.).
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import logging
11+
from typing import Any, Dict, List
12+
13+
from application.core.settings import settings
14+
from application.storage.db.repositories.agents import AgentsRepository
15+
from application.storage.db.session import db_readonly
16+
from application.vectorstore.vector_creator import VectorCreator
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class InvalidAPIKey(Exception):
    """Raised when no agent can be resolved from the supplied ``api_key``."""
23+
24+
25+
class SearchFailed(Exception):
    """Raised on unexpected retrieval errors (e.g. DB outage); callers map it to 5xx."""
27+
28+
29+
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
30+
"""Extract the ordered list of source UUIDs to search.
31+
32+
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
33+
falls back to the legacy single ``source_id`` field.
34+
"""
35+
source_ids: List[str] = []
36+
extra = agent.get("extra_source_ids") or []
37+
for src in extra:
38+
if src:
39+
source_ids.append(str(src))
40+
if not source_ids:
41+
single = agent.get("source_id")
42+
if single:
43+
source_ids.append(str(single))
44+
return source_ids
45+
46+
47+
def _search_sources(
    query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
    """Search each source's vectorstore and return up to ``chunks`` hits.

    Per-source errors are logged and skipped so one broken index doesn't
    take down the whole search. Results are de-duplicated by content hash.

    Args:
        query: Free-text search query.
        source_ids: Source UUIDs (as strings) to search.
        chunks: Maximum number of hits to return overall.

    Returns:
        Up to ``chunks`` dicts with ``text``, ``title`` and ``source`` keys.
    """
    # Drop blank/whitespace ids *before* splitting the chunk budget, so
    # invalid entries don't dilute chunks_per_source for the valid ones.
    valid_ids = [sid for sid in source_ids if sid and sid.strip()]
    if not valid_ids:
        return []

    results: List[Dict[str, Any]] = []
    chunks_per_source = max(1, chunks // len(valid_ids))
    seen_texts: set[int] = set()

    for source_id in valid_ids:
        try:
            docsearch = VectorCreator.create_vectorstore(
                settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
            )
            # Over-fetch (2x) per source so de-duplication still leaves
            # enough distinct hits.
            docs = docsearch.search(query, k=chunks_per_source * 2)

            for doc in docs:
                if len(results) >= chunks:
                    break

                # Hits may be Document-like objects or plain dicts.
                if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                    page_content = doc.page_content
                    metadata = doc.metadata
                else:
                    page_content = doc.get("text", doc.get("page_content", ""))
                    metadata = doc.get("metadata", {})

                # Dedupe on a prefix hash: cheap, and near-identical chunks
                # almost always share their first 200 characters.
                text_hash = hash(page_content[:200])
                if text_hash in seen_texts:
                    continue
                seen_texts.add(text_hash)

                title = metadata.get("title", metadata.get("post_title", ""))
                if not isinstance(title, str):
                    title = str(title) if title else ""

                if title:
                    # Strip any path prefix so only the leaf name shows.
                    title = title.split("/")[-1]
                else:
                    # Fall back to the filename, then to a content snippet.
                    title = metadata.get("filename", page_content[:50] + "...")

                source = metadata.get("source", source_id)

                results.append(
                    {
                        "text": page_content,
                        "title": title,
                        "source": source,
                    }
                )

        except Exception:
            # Best-effort: log (lazy %-args, traceback attached) and keep
            # searching the remaining sources.
            logger.error(
                "Error searching vectorstore %s", source_id, exc_info=True
            )
            continue

        if len(results) >= chunks:
            break

    return results[:chunks]
118+
119+
120+
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
    """Look up the agent for ``api_key`` and run retrieval over its sources.

    Args:
        api_key: Agent API key (the opaque string stored on ``agents.key``
            in Postgres).
        query: Free-text search query.
        chunks: Maximum number of hits to return.

    Returns:
        Hit dicts with ``text``, ``title`` and ``source`` keys; an empty
        list when the agent has no sources configured.

    Raises:
        InvalidAPIKey: ``api_key`` did not resolve to an agent.
        SearchFailed: unexpected DB / infrastructure failure.
    """
    try:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
    except Exception as exc:
        # Infrastructure problem, not a bad key: surface as SearchFailed.
        raise SearchFailed("agent lookup failed") from exc

    if not agent:
        raise InvalidAPIKey()

    ids = _collect_source_ids(agent)
    return _search_sources(query, ids, chunks) if ids else []

0 commit comments

Comments
 (0)