66
77import ollama
88import pandas as pd
9+ from colorama import Fore
910from langchain .prompts import ChatPromptTemplate
1011from langchain .schema import AIMessage
1112from langchain .schema .runnable import Runnable
@@ -56,6 +57,16 @@ def get_line(lineno):
5657 return get_line
5758
5859
60+ class GraphRetrievalException (Exception ):
61+ def __init__ (self , message , query ):
62+ self .query = query
63+ super ().__init__ (message )
64+
65+
66+ class ContextAssemblerException (Exception ):
67+ pass
68+
69+
5970class GraphRAG (Runnable ):
6071 def __init__ (
6172 self ,
@@ -170,18 +181,27 @@ def query_graph(
170181 limit : Optional [int ] = None ,
171182 ) -> RunnableFn :
172183 def run (inputs : dict [str , Any ]) -> dict [str , Any ]:
184+ log .info (
185+ "Querying graph for matching entities (shuffle={}, limit={})" ,
186+ shuffle ,
187+ limit ,
188+ )
189+
173190 query = inputs ["query" ]
174191 params = inputs .get ("params" )
175192
176- context_df = pd .DataFrame (self .graph .query (query , params ))
193+ try :
194+ context_df = pd .DataFrame (self .graph .query (query , params ))
177195
178- if shuffle :
179- context_df = context_df .sample (frac = 1 )
196+ if shuffle :
197+ context_df = context_df .sample (frac = 1 )
180198
181- if limit is not None :
182- context_df = context_df .head (limit )
199+ if limit is not None :
200+ context_df = context_df .head (limit )
183201
184- return dict (context = context_df )
202+ return dict (context = context_df )
203+ except :
204+ raise GraphRetrievalException ("Graph query failed" , query = query )
185205
186206 return run
187207
@@ -214,6 +234,10 @@ def combined_knn(self, k: int) -> RunnableFn:
214234
215235 def run (inputs : dict [str , Any ]) -> dict [str , Any ]:
216236 context = inputs ["context" ]
237+
238+ if context is None or len (context ) == 0 :
239+ raise ContextAssemblerException ("Context not found" )
240+
217241 node_ids = context .node_id .to_list ()
218242
219243 for node_id in node_ids :
@@ -247,9 +271,17 @@ def nn_sample_shortest_paths(
247271 max_length : int ,
248272 ) -> RunnableFn :
249273 def run (inputs : dict [str , Any ]) -> dict [str , Any ]:
250- source_node_ids = inputs ["graph_retrieval" ]["context" ].node_id .to_list ()
274+ context = inputs ["graph_retrieval" ]["context" ]
275+
276+ if context is None or len (context ) == 0 :
277+ raise ContextAssemblerException ("Context not found" )
278+
279+ source_node_ids = context .node_id .to_list ()
251280 target_node_ids = inputs ["combined_knn" ]["knn" ]
252281
282+ if target_node_ids is None or len (target_node_ids ) == 0 :
283+ raise ContextAssemblerException ("Nearest neighbors not found" )
284+
253285 paths_df = self .ops .sample_shortest_paths (
254286 source_node_ids ,
255287 target_node_ids ,
@@ -271,6 +303,9 @@ def nn_random_walks(
271303 def run (inputs : dict [str , Any ]) -> dict [str , Any ]:
272304 source_node_ids = inputs ["combined_knn" ]["knn" ]
273305
306+ if source_node_ids is None or len (source_node_ids ) == 0 :
307+ raise ContextAssemblerException ("Nearest neighbors not found" )
308+
274309 paths_dfs = []
275310
276311 for source_node_id in source_node_ids :
@@ -412,7 +447,7 @@ def loader(self, stop_event: threading.Event):
412447
413448 time .sleep (0.1 )
414449
415- print (f" \r ⏱ { elapsed } \n " )
450+ print (" \b \b \b " , end = " \n \n " , flush = True )
416451
417452 def interactive (self ):
418453 config_path = user_config_path ("datalab" , "DataLabTechTV" )
@@ -461,9 +496,17 @@ def interactive(self):
461496 )
462497 loader_thread .start ()
463498
464- output = self .invoke (dict (user_query = user_query ))
465-
466- stop_event .set ()
467- loader_thread .join ()
468-
469- print (output ["context" ])
499+ try :
500+ response = self .invoke (dict (user_query = user_query ))
501+ stop_event .set ()
502+ loader_thread .join ()
503+ print (response .content )
504+ except GraphRetrievalException as e :
505+ stop_event .set ()
506+ loader_thread .join ()
507+ print (Fore .RED + "Error: " + str (e ))
508+ print (Fore .MAGENTA + e .query + Fore .RESET )
509+ except ContextAssemblerException as e :
510+ stop_event .set ()
511+ loader_thread .join ()
512+ print (Fore .RED + "Error: " + str (e ) + Fore .RESET )
0 commit comments