
Commit 028fa39

implement redis pool
1 parent 1b6cf2a commit 028fa39

4 files changed

Lines changed: 106 additions & 73 deletions


src/api/main.py

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,7 @@
 from src.api.models import ChatCompletionRequest, ChatRequest, Message
 from src.api.translatations import _get_error_messages
 from src.chatbot.agents.graph import CampusManagementAgent
+from src.chatbot.db.redis_pool import redis_client
 from src.chatbot.prompt.prompt_date import get_current_date
 from src.chatbot.tools.utils.exceptions import ProgrammableSearchException
 from src.chatbot_log.chatbot_logger import logger
@@ -36,13 +37,15 @@
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # TODO: Move initialization of singletons and settings here
+    await redis_client.initialize()
     # Startup: eagerly initialize the singleton so the first request isn't slow
     agent = CampusManagementAgent()
     await agent._ensure_async_initialized()
     app.state.agent = agent
     yield
     # Shutdown: clean up Redis connection
     await agent.cleanup()
+    await redis_client.cleanup()


 # TODO: Refactor key management (should be more robust)
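A quick usage sketch of what this lifespan wiring enables: once initialize() has run at startup, any request handler can borrow a connection from the shared pool via redis_client.client instead of opening one per request. The /cache-demo route below is hypothetical and not part of this commit.

from fastapi import FastAPI

from src.api.main import lifespan  # the lifespan hook shown in the diff above
from src.chatbot.db.redis_pool import redis_client

app = FastAPI(lifespan=lifespan)


@app.get("/cache-demo")
async def cache_demo(key: str):
    # Borrows a pooled connection only for the duration of this call.
    value = await redis_client.client.get(key)
    return {"key": key, "value": value}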

src/chatbot/agents/graph.py

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ async def _create_graph(self):
 from src.chatbot.prompt.main import get_system_prompt
 from src.chatbot.prompt.prompt_date import get_current_date

+# If the graph is run without FastAPI, the Redis pool must be initialized manually and closed when done
 graph = CampusManagementAgent()

 def print_graph(graph, filename="graph.png"):
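The added comment flags that, outside the FastAPI lifespan, the pool has to be opened and closed by hand. A minimal sketch of a standalone run under that assumption; the ainvoke call and its payload are placeholders, since the agent's actual entry point is not shown in this commit.

import asyncio

from src.chatbot.agents.graph import CampusManagementAgent
from src.chatbot.db.redis_pool import redis_client


async def main():
    # No lifespan hook here, so the pool is managed explicitly.
    await redis_client.initialize()
    try:
        graph = CampusManagementAgent()
        # Hypothetical invocation; adapt to the agent's real interface.
        result = await graph.ainvoke({"messages": [("user", "When does enrollment open?")]})
        print(result)
    finally:
        await redis_client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())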

src/chatbot/db/redis_pool.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import asyncio
+import threading
+from concurrent.futures import Future
+from typing import Any
+
+import redis.asyncio as aioredis
+
+from src.chatbot_log.chatbot_logger import logger
+
+
+class RedisClient:
+    _instance = None
+    _pool = None
+    # _thread_lock = threading.Lock()
+
+    def __new__(cls):
+
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._pool = None
+            cls._instance._lock = None
+        return cls._instance
+
+    async def initialize(self, host: str = "redis", port: int = 6379):
+        """Create the connection pool once."""
+
+        if self._lock is None:
+            self._lock = asyncio.Lock()
+        async with self._lock:
+            if self._pool is None:
+                self._pool = aioredis.BlockingConnectionPool(
+                    host=host,
+                    port=port,
+                    timeout=15,
+                    decode_responses=True,
+                    max_connections=50,
+                )
+                logger.info("[REDIS] Connection pool initialized")
+
+    @property
+    def client(self) -> aioredis.Redis:
+        """Return a client using the shared pool — no new connection created."""
+        if self._pool is None:
+            raise RuntimeError("RedisClient not initialized. Call initialize() first.")
+        return aioredis.Redis(connection_pool=self._pool)
+
+    async def cleanup(self):
+        """Close the pool on shutdown."""
+        if self._lock is None:
+            return
+        async with self._lock:
+            if self._pool:
+                await self._pool.disconnect()
+                logger.info("[REDIS] Connection pool closed")
+                self._pool = None
+
+
+# Singleton instance
+redis_client = RedisClient()
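Two design points worth noting: BlockingConnectionPool caps the service at 50 connections and, when all of them are in use, makes callers wait up to timeout=15 seconds for a free one instead of failing immediately, and the client property returns a lightweight Redis object bound to that shared pool, so creating one per task does not open extra TCP connections. A minimal usage sketch under the assumption that a Redis server is reachable at the Compose-style hostname "redis":

import asyncio

from src.chatbot.db.redis_pool import redis_client


async def demo():
    await redis_client.initialize()      # safe to call more than once; the pool is created a single time
    client = redis_client.client         # cheap view over the shared pool
    await client.set("greeting", "hello", ex=60)
    print(await client.get("greeting"))  # "hello" (strings, thanks to decode_responses=True)
    await redis_client.cleanup()


asyncio.run(demo())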

src/chatbot/tools/search_web_tool.py

Lines changed: 43 additions & 73 deletions
@@ -9,11 +9,11 @@
 import aiohttp
 import dotenv
 import nest_asyncio
-import redis.asyncio as aioredis
 import redis.asyncio as redis

 from src.chatbot.agents.models import RetrievalResult, ScrapeResult
 from src.chatbot.agents.utils.agent_helpers import model_registry
+from src.chatbot.db.redis_pool import redis_client
 from src.chatbot.tools.utils.exceptions import ProgrammableSearchException
 from src.chatbot.tools.utils.tool_helpers import decode_string
 from src.chatbot.utils.helpers import compute_search_num_tokens
@@ -194,30 +194,6 @@ async def visit_urls_extract(
     cache_key_prefix = f"{__name__}:visit_urls_extract:"
     cache_tasks = []
     async with aiohttp.ClientSession() as session:
-        # Query Google search API
-        # async with session.get(url) as response:
-        #     if response.status != 200:
-        #         raise ProgrammableSearchException(
-        #             f"Failed: Programmable Search Engine. Status: {response.status}"
-        #         )
-
-        #     # Parse JSON response
-        #     dict_response = await response.json()
-
-        #     # Check if there are results
-        #     total_results = dict_response.get("searchInformation", {}).get(
-        #         "totalResults", 0
-        #     )
-        #     if int(total_results) > 0:
-        #         links_search = [item["link"] for item in dict_response["items"]]
-        #         logger.debug(
-        #             f"[SEARCH] Search Engine returned {len(links_search)} results (links)"
-        #         )
-        #     else:
-        #         logger.warning(
-        #             f"[SEARCH] No results found by the search engine while requesting this URL: {url}"
-        #         )
-        #         return [], []

         links_search = await _google_search(session, url)
         if not links_search:
@@ -311,60 +287,54 @@ async def visit_urls_extract(

 async def async_search(**kwargs) -> Tuple[str, List]:
     """Asynchronous search function that encapsulates the search functionality."""
+
     try:
-        # client = redis.Redis(host="redis", port=6379, decode_responses=True)
-        # client = aioredis.Redis(host="redis", port=6379, decode_responses=True)
-        # logger.debug("[REDIS] Async client created: %s", client)
-        # await RedisPool.get_pool()
-        # client = RedisPool.get_client()
-
-        async with aioredis.Redis(
-            host="redis", port=6379, decode_responses=True
-        ) as client:
-            logger.debug("[REDIS] Async client created: %s", client)
-            query = kwargs.get("query", "")
-            query_url = decode_string(query)
-            url = SEARCH_URL + query_url
-            do_not_visit_links = kwargs.get("do_not_visit_links", [])
-            about_application = kwargs.get("about_application", False)
-
-            # -------------------------- cache lookup --------------------------
-            cache_key = f"{__name__}:async_search:{url}"
-            cached_content = await client.get(cache_key)
-            if cached_content:
-                logger.debug("[REDIS] Retrieved cached searched results (urls)")
-                return RetrievalResult.from_json(cached_content)
-
-            logger.debug("[SEARCH] Cache miss – proceeding with live search")
-
-            visited_urls, contents = await visit_urls_extract(
-                url=url,
-                query=query,
-                about_application=about_application,
-                do_not_visit_links=do_not_visit_links,
-                client=client,
-            )

-            final_output = "\n".join(contents)
+        client = redis_client.client
+        logger.debug("[REDIS] Async client created: %s", client)
+        query = kwargs.get("query", "")
+        query_url = decode_string(query)
+        url = SEARCH_URL + query_url
+        do_not_visit_links = kwargs.get("do_not_visit_links", [])
+        about_application = kwargs.get("about_application", False)
+
+        # -------------------------- cache lookup --------------------------
+        cache_key = f"{__name__}:async_search:{url}"
+        cached_content = await client.get(cache_key)
+        if cached_content:
+            logger.debug("[REDIS] Retrieved cached searched results (urls)")
+            return RetrievalResult.from_json(cached_content)
+
+        logger.debug("[SEARCH] Cache miss – proceeding with live search")
+
+        visited_urls, contents = await visit_urls_extract(
+            url=url,
+            query=query,
+            about_application=about_application,
+            do_not_visit_links=do_not_visit_links,
+            client=client,
+        )

-            if final_output:
-                # For testing
-                final_output_tokens, final_search_tokens = compute_tokens(
-                    final_output, query
-                )
-                logger.info(f"[SEARCH] Search tokens: {final_search_tokens}")
-                logger.info(
-                    f"[SEARCH] Final output (search + prompt): {final_output_tokens}"
-                )
+        final_output = "\n".join(contents)

-            retrieved = RetrievalResult(
-                result_text=final_output, reference=visited_urls, search_query=query
+        if final_output:
+            # For testing
+            final_output_tokens, final_search_tokens = compute_tokens(
+                final_output, query
+            )
+            logger.info(f"[SEARCH] Search tokens: {final_search_tokens}")
+            logger.info(
+                f"[SEARCH] Final output (search + prompt): {final_output_tokens}"
            )
-            # -------------------------- cache store ---------------------------
-            if len(final_output) > 20:
-                await client.setex(cache_key, TTL, retrieved.to_json())

-            return retrieved
+        retrieved = RetrievalResult(
+            result_text=final_output, reference=visited_urls, search_query=query
+        )
+        # -------------------------- cache store ---------------------------
+        if len(final_output) > 20:
+            await client.setex(cache_key, TTL, retrieved.to_json())
+
+        return retrieved

     except redis.ConnectionError as e:
         logger.error(

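The net effect in this file: async_search no longer opens and tears down its own Redis connection per call inside an async with block; it borrows a client from the process-wide pool, and connections return to the pool automatically as each command completes. The cache flow itself is unchanged: look the key up with get, and on a miss run the live search and store the serialized result with setex under TTL when the output is longer than 20 characters. A rough, self-contained illustration of that flow; cached_fetch, ttl_seconds, and the placeholder value are assumptions for the sketch, not project code.

import redis.asyncio as aioredis


async def cached_fetch(client: aioredis.Redis, key: str, ttl_seconds: int = 3600) -> str:
    # Cache lookup first ...
    cached = await client.get(key)
    if cached:
        return cached
    # ... on a miss, do the expensive work and store it with an expiry.
    fresh = "placeholder for the live search result"
    if len(fresh) > 20:
        await client.setex(key, ttl_seconds, fresh)
    return fresh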