|
9 | 9 | import aiohttp |
10 | 10 | import dotenv |
11 | 11 | import nest_asyncio |
12 | | -import redis.asyncio as aioredis |
13 | 12 | import redis.asyncio as redis |
14 | 13 |
|
15 | 14 | from src.chatbot.agents.models import RetrievalResult, ScrapeResult |
16 | 15 | from src.chatbot.agents.utils.agent_helpers import model_registry |
| 16 | +from src.chatbot.db.redis_pool import redis_client |
17 | 17 | from src.chatbot.tools.utils.exceptions import ProgrammableSearchException |
18 | 18 | from src.chatbot.tools.utils.tool_helpers import decode_string |
19 | 19 | from src.chatbot.utils.helpers import compute_search_num_tokens |
@@ -194,30 +194,6 @@ async def visit_urls_extract( |
194 | 194 | cache_key_prefix = f"{__name__}:visit_urls_extract:" |
195 | 195 | cache_tasks = [] |
196 | 196 | async with aiohttp.ClientSession() as session: |
197 | | - # Query Google search API |
198 | | - # async with session.get(url) as response: |
199 | | - # if response.status != 200: |
200 | | - # raise ProgrammableSearchException( |
201 | | - # f"Failed: Programmable Search Engine. Status: {response.status}" |
202 | | - # ) |
203 | | - |
204 | | - # # Parse JSON response |
205 | | - # dict_response = await response.json() |
206 | | - |
207 | | - # # Check if there are results |
208 | | - # total_results = dict_response.get("searchInformation", {}).get( |
209 | | - # "totalResults", 0 |
210 | | - # ) |
211 | | - # if int(total_results) > 0: |
212 | | - # links_search = [item["link"] for item in dict_response["items"]] |
213 | | - # logger.debug( |
214 | | - # f"[SEARCH] Search Engine returned {len(links_search)} results (links)" |
215 | | - # ) |
216 | | - # else: |
217 | | - # logger.warning( |
218 | | - # f"[SEARCH] No results found by the search engine while requesting this URL: {url}" |
219 | | - # ) |
220 | | - # return [], [] |
221 | 197 |
|
222 | 198 | links_search = await _google_search(session, url) |
223 | 199 | if not links_search: |
@@ -311,60 +287,54 @@ async def visit_urls_extract( |
311 | 287 |
|
312 | 288 | async def async_search(**kwargs) -> Tuple[str, List]: |
313 | 289 | """Asynchronous search function that encapsulates the search functionality.""" |
| 290 | + |
314 | 291 | try: |
315 | | - # client = redis.Redis(host="redis", port=6379, decode_responses=True) |
316 | | - # client = aioredis.Redis(host="redis", port=6379, decode_responses=True) |
317 | | - # logger.debug("[REDIS] Async client created: %s", client) |
318 | | - # await RedisPool.get_pool() |
319 | | - # client = RedisPool.get_client() |
320 | | - |
321 | | - async with aioredis.Redis( |
322 | | - host="redis", port=6379, decode_responses=True |
323 | | - ) as client: |
324 | | - logger.debug("[REDIS] Async client created: %s", client) |
325 | | - query = kwargs.get("query", "") |
326 | | - query_url = decode_string(query) |
327 | | - url = SEARCH_URL + query_url |
328 | | - do_not_visit_links = kwargs.get("do_not_visit_links", []) |
329 | | - about_application = kwargs.get("about_application", False) |
330 | | - |
331 | | - # -------------------------- cache lookup -------------------------- |
332 | | - cache_key = f"{__name__}:async_search:{url}" |
333 | | - cached_content = await client.get(cache_key) |
334 | | - if cached_content: |
335 | | - logger.debug("[REDIS] Retrieved cached searched results (urls)") |
336 | | - return RetrievalResult.from_json(cached_content) |
337 | | - |
338 | | - logger.debug("[SEARCH] Cache miss – proceeding with live search") |
339 | | - |
340 | | - visited_urls, contents = await visit_urls_extract( |
341 | | - url=url, |
342 | | - query=query, |
343 | | - about_application=about_application, |
344 | | - do_not_visit_links=do_not_visit_links, |
345 | | - client=client, |
346 | | - ) |
347 | 292 |
|
348 | | - final_output = "\n".join(contents) |
| 293 | + client = redis_client.client |
| 294 | + logger.debug("[REDIS] Async client created: %s", client) |
| 295 | + query = kwargs.get("query", "") |
| 296 | + query_url = decode_string(query) |
| 297 | + url = SEARCH_URL + query_url |
| 298 | + do_not_visit_links = kwargs.get("do_not_visit_links", []) |
| 299 | + about_application = kwargs.get("about_application", False) |
| 300 | + |
| 301 | + # -------------------------- cache lookup -------------------------- |
| 302 | + cache_key = f"{__name__}:async_search:{url}" |
| 303 | + cached_content = await client.get(cache_key) |
| 304 | + if cached_content: |
| 305 | + logger.debug("[REDIS] Retrieved cached searched results (urls)") |
| 306 | + return RetrievalResult.from_json(cached_content) |
| 307 | + |
| 308 | + logger.debug("[SEARCH] Cache miss – proceeding with live search") |
| 309 | + |
| 310 | + visited_urls, contents = await visit_urls_extract( |
| 311 | + url=url, |
| 312 | + query=query, |
| 313 | + about_application=about_application, |
| 314 | + do_not_visit_links=do_not_visit_links, |
| 315 | + client=client, |
| 316 | + ) |
349 | 317 |
|
350 | | - if final_output: |
351 | | - # For testing |
352 | | - final_output_tokens, final_search_tokens = compute_tokens( |
353 | | - final_output, query |
354 | | - ) |
355 | | - logger.info(f"[SEARCH] Search tokens: {final_search_tokens}") |
356 | | - logger.info( |
357 | | - f"[SEARCH] Final output (search + prompt): {final_output_tokens}" |
358 | | - ) |
| 318 | + final_output = "\n".join(contents) |
359 | 319 |
|
360 | | - retrieved = RetrievalResult( |
361 | | - result_text=final_output, reference=visited_urls, search_query=query |
| 320 | + if final_output: |
| 321 | + # For testing |
| 322 | + final_output_tokens, final_search_tokens = compute_tokens( |
| 323 | + final_output, query |
| 324 | + ) |
| 325 | + logger.info(f"[SEARCH] Search tokens: {final_search_tokens}") |
| 326 | + logger.info( |
| 327 | + f"[SEARCH] Final output (search + prompt): {final_output_tokens}" |
362 | 328 | ) |
363 | | - # -------------------------- cache store --------------------------- |
364 | | - if len(final_output) > 20: |
365 | | - await client.setex(cache_key, TTL, retrieved.to_json()) |
366 | 329 |
|
367 | | - return retrieved |
| 330 | + retrieved = RetrievalResult( |
| 331 | + result_text=final_output, reference=visited_urls, search_query=query |
| 332 | + ) |
| 333 | + # -------------------------- cache store --------------------------- |
| 334 | + if len(final_output) > 20: |
| 335 | + await client.setex(cache_key, TTL, retrieved.to_json()) |
| 336 | + |
| 337 | + return retrieved |
368 | 338 |
|
369 | 339 | except redis.ConnectionError as e: |
370 | 340 | logger.error( |
|
0 commit comments