|
1 | | -from collections.abc import Iterable |
2 | | - |
3 | 1 | import pandas as pd |
4 | 2 |
|
5 | 3 | from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator |
@@ -28,7 +26,7 @@ def __init__( |
28 | 26 | ) |
29 | 27 | self.generator = QuizGenerator(self.llm_client) |
30 | 28 |
|
31 | | - def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: |
| 29 | + def process(self, batch: pd.DataFrame) -> pd.DataFrame: |
32 | 30 | data = batch.to_dict(orient="records") |
33 | 31 | self.graph_storage.reload() |
34 | 32 | return self.quiz(data) |
@@ -62,22 +60,22 @@ async def _process_single_quiz(self, item: tuple) -> dict | None: |
62 | 60 | logger.error("Error when quizzing description %s: %s", item, e) |
63 | 61 | return None |
64 | 62 |
|
65 | | - def quiz(self, batch) -> Iterable[pd.DataFrame]: |
| 63 | + def quiz(self, batch) -> pd.DataFrame: |
66 | 64 | """ |
67 | 65 | Get all nodes and edges and quiz their descriptions using QuizGenerator. |
68 | 66 | """ |
69 | 67 | items = [] |
70 | 68 |
|
71 | 69 | for item in batch: |
72 | | - nodes = item.get("nodes", []) |
73 | | - edges = item.get("edges", []) |
| 70 | + node_data = item.get("node", []) |
| 71 | + edge_data = item.get("edge", []) |
74 | 72 |
|
75 | | - for node_id, node_data in nodes.items(): |
76 | | - node_data = node_data[0] |
| 73 | + if node_data: |
| 74 | + node_id = node_data["entity_name"] |
77 | 75 | desc = node_data["description"] |
78 | 76 | items.append((node_id, desc)) |
79 | | - for edge_key, edge_data in edges.items(): |
80 | | - edge_data = edge_data[0] |
| 77 | + if edge_data: |
| 78 | + edge_key = (edge_data["src_id"], edge_data["tgt_id"]) |
81 | 79 | desc = edge_data["description"] |
82 | 80 | items.append((edge_key, desc)) |
83 | 81 |
|
|
0 commit comments