2828from langchain .text_splitter import RecursiveCharacterTextSplitter
2929
3030from src .chatbot .agents .utils .agent_helpers import llm_optional as sumarize_llm
31+ from src .chatbot .agents .utils .agent_retriever import retrieve_from_infinity_ragflow
3132
3233# from src.chatbot.db.redis_client import redis_manager
3334from src .chatbot .tools .utils .custom_crawl import (
4243from src .chatbot .tools .utils .tool_helpers import decode_string
4344from src .chatbot_log .chatbot_logger import logger
4445from src .config .core_config import settings
46+ from src .config .models import CollectionNames , SearchEngineTypes , VectorDBTypes
4547
4648colorama .init (strip = True )
4749
@@ -363,16 +365,39 @@ async def async_search(client, **kwargs) -> Tuple[str, List]:
363365
364366 agent_executor = kwargs ["agent_executor" ]
365367
366- visited_urls , contents = await visit_urls_extract (
367- url = url ,
368- query = query ,
369- agent_executor = agent_executor ,
370- about_application = about_application ,
371- do_not_visit_links = do_not_visit_links ,
372- client = client ,
373- )
def extract_urls_from_content(refs):
    """Collect the ``url_reference_web_uos`` value from each retrieval reference.

    Args:
        refs: iterable of reference objects (RAGFlow retrieval refs) each
            exposing a ``url_reference_web_uos`` attribute — presumably the
            source web URL; verify against the retriever's reference type.

    Returns:
        list: the attribute values in the same order as ``refs``
        (empty list for empty input).
    """
    # Comprehension replaces the manual append loop: same order, same values.
    return [ref.url_reference_web_uos for ref in refs]
373+
374+ SEARCH_TYPE = settings .application .search_engine_type
375+ if SEARCH_TYPE == SearchEngineTypes .RAGFlow_search :
376+ try :
377+ contents , ref = retrieve_from_infinity_ragflow (
378+ CollectionNames .WEB_UOS , query
379+ )
380+ visited_urls = extract_urls_from_content (ref )
381+ final_output = contents
382+ print ()
383+
384+ except Exception as e :
385+ logger .error (f"[RAGFlow] Error during retrieval: { e } " )
386+ final_output = ""
387+ visited_urls = []
374388
375- final_output = "\n " .join (contents )
389+ else :
390+
391+ visited_urls , contents = await visit_urls_extract (
392+ url = url ,
393+ query = query ,
394+ agent_executor = agent_executor ,
395+ about_application = about_application ,
396+ do_not_visit_links = do_not_visit_links ,
397+ client = client ,
398+ )
399+
400+ final_output = "\n " .join (contents )
376401
377402 if final_output :
378403 # For testing
@@ -383,7 +408,7 @@ async def async_search(client, **kwargs) -> Tuple[str, List]:
383408 logger .info (
384409 f"[SEARCH] Final output (search + prompt): { final_output_tokens } "
385410 )
386-
411+ # TODO: change the cache_key if the search engine is ragflow
387412 # Cache results
388413 if len (final_output ) > 20 :
389414 cache_value = str ((final_output , visited_urls ))
0 commit comments