1- from collections .abc import Iterable
2-
31import pandas as pd
42
53from graphgen .bases import BaseGraphStorage , BaseKVStorage , BaseLLMWrapper , BaseOperator
@@ -15,7 +13,6 @@ def __init__(
1513 graph_backend : str = "kuzu" ,
1614 kv_backend : str = "rocksdb" ,
1715 quiz_samples : int = 1 ,
18- concurrency_limit : int = 200 ,
1916 ):
2017 super ().__init__ (working_dir = working_dir , op_name = "quiz_service" )
2118 self .quiz_samples = quiz_samples
@@ -28,21 +25,16 @@ def __init__(
2825 backend = kv_backend , working_dir = working_dir , namespace = "quiz"
2926 )
3027 self .generator = QuizGenerator (self .llm_client )
31- self .concurrency_limit = concurrency_limit
3228
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
    """Quiz every node/edge description carried by *batch*.

    Refreshes the graph-storage view, converts the incoming frame into a
    list of record dicts, and hands them to ``quiz``, returning its
    DataFrame of quiz results.
    """
    # reload() and to_dict() are independent; refresh storage first so the
    # quiz step always sees the latest graph state.
    self.graph_storage.reload()
    records = batch.to_dict(orient="records")
    return self.quiz(records)
3933
4034 async def _process_single_quiz (self , item : tuple ) -> dict | None :
4135 # if quiz in quiz_storage exists already, directly get it
4236 index , desc = item
4337 _quiz_id = compute_dict_hash ({"index" : index , "description" : desc })
44- if self .quiz_storage .get_by_id (_quiz_id ):
45- return None
4638
4739 tasks = []
4840 for i in range (self .quiz_samples ):
@@ -68,47 +60,43 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
6860 logger .error ("Error when quizzing description %s: %s" , item , e )
6961 return None
7062
def quiz(self, batch) -> pd.DataFrame:
    """
    Get all nodes and edges and quiz their descriptions using QuizGenerator.

    Each record in *batch* may carry a ``"node"`` entry (identified by its
    ``entity_name``) and/or an ``"edge"`` entry (identified by the
    ``(src_id, tgt_id)`` pair). Every description found is quizzed
    concurrently, successful results are persisted to quiz storage, and a
    DataFrame of those results is returned.
    """
    pending: list[tuple] = []
    for record in batch:
        # Missing keys fall back to an empty (falsy) value and are skipped.
        node = record.get("node", [])
        if node:
            pending.append((node["entity_name"], node["description"]))
        edge = record.get("edge", [])
        if edge:
            pending.append(((edge["src_id"], edge["tgt_id"]), edge["description"]))

    logger.info("Total descriptions to quiz: %d", len(pending))

    outcomes = run_concurrent(
        self._process_single_quiz,
        pending,
        desc=f"Quizzing batch of {len(pending)} descriptions",
        unit="description",
    )
    # _process_single_quiz returns None on failure; keep only real results.
    kept = [outcome for outcome in outcomes if outcome]

    for outcome in kept:
        entry = {
            "description": outcome["description"],
            "quizzes": outcome["quizzes"],
        }
        self.quiz_storage.upsert({outcome["_quiz_id"]: entry})
    # Flush/commit the KV storage after all upserts.
    self.quiz_storage.index_done_callback()

    return pd.DataFrame(kept)
0 commit comments