Skip to content

Commit a3a3c41

Browse files
feat: make build_kg return list of nodes & edges
1 parent 931c4cb commit a3a3c41

5 files changed

Lines changed: 34 additions & 23 deletions

File tree

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def merge_nodes(
9999
self,
100100
node_data: tuple[str, List[dict]],
101101
kg_instance: BaseGraphStorage,
102-
) -> None:
102+
) -> dict:
103103
entity_name, node_data = node_data
104104
entity_types = []
105105
source_ids = []
@@ -131,16 +131,18 @@ async def merge_nodes(
131131

132132
node_data = {
133133
"entity_type": entity_type,
134+
"entity_name": entity_name,
134135
"description": description,
135136
"source_id": source_id,
136137
}
137138
kg_instance.upsert_node(entity_name, node_data=node_data)
139+
return node_data
138140

139141
async def merge_edges(
140142
self,
141143
edges_data: tuple[Tuple[str, str], List[dict]],
142144
kg_instance: BaseGraphStorage,
143-
) -> None:
145+
) -> dict:
144146
(src_id, tgt_id), edge_data = edges_data
145147

146148
source_ids = []
@@ -175,11 +177,19 @@ async def merge_edges(
175177
f"({src_id}, {tgt_id})", description
176178
)
177179

180+
edge_data = {
181+
"src_id": src_id,
182+
"tgt_id": tgt_id,
183+
"description": description,
184+
"source_id": source_id, # for traceability
185+
}
186+
178187
kg_instance.upsert_edge(
179188
src_id,
180189
tgt_id,
181190
edge_data={"source_id": source_id, "description": description},
182191
)
192+
return edge_data
183193

184194
async def _handle_kg_summary(
185195
self,

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
2929

3030
# consume the chunks and build kg
3131
nodes, edges = self.build_kg(docs)
32-
return pd.DataFrame([{"nodes": nodes, "edges": edges}])
32+
return pd.DataFrame(
33+
[{"node": node, "edge": []} for node in nodes]
34+
+ [{"node": [], "edge": edge} for edge in edges]
35+
)
3336

3437
def build_kg(self, chunks: List[Chunk]) -> tuple:
3538
"""
@@ -42,8 +45,8 @@ def build_kg(self, chunks: List[Chunk]) -> tuple:
4245
if chunk.type in ("image", "video", "table", "formula")
4346
]
4447

45-
nodes = {}
46-
edges = {}
48+
nodes = []
49+
edges = []
4750

4851
if len(text_chunks) == 0:
4952
logger.info("All text chunks are already in the storage")
@@ -55,8 +58,8 @@ def build_kg(self, chunks: List[Chunk]) -> tuple:
5558
chunks=text_chunks,
5659
max_loop=self.max_loop,
5760
)
58-
nodes.update(text_nodes)
59-
edges.update(text_edges)
61+
nodes += text_nodes
62+
edges += text_edges
6063
if len(mm_chunks) == 0:
6164
logger.info("All multi-modal chunks are already in the storage")
6265
else:
@@ -66,8 +69,8 @@ def build_kg(self, chunks: List[Chunk]) -> tuple:
6669
kg_instance=self.graph_storage,
6770
chunks=mm_chunks,
6871
)
69-
nodes.update(mm_nodes)
70-
edges.update(mm_edges)
72+
nodes += mm_nodes
73+
edges += mm_edges
7174

7275
self.graph_storage.index_done_callback()
7376
logger.info("Knowledge graph building completed.")

graphgen/operators/build_kg/build_mm_kg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def build_mm_kg(
3737
for k, v in e.items():
3838
edges[tuple(sorted(k))].extend(v)
3939

40-
run_concurrent(
40+
nodes = run_concurrent(
4141
lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
4242
list(nodes.items()),
4343
desc="Inserting entities into storage",
4444
)
4545

46-
run_concurrent(
46+
edges = run_concurrent(
4747
lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
4848
list(edges.items()),
4949
desc="Inserting relationships into storage",

graphgen/operators/build_kg/build_text_kg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def build_text_kg(
3939
for k, v in e.items():
4040
edges[tuple(sorted(k))].extend(v)
4141

42-
run_concurrent(
42+
nodes = run_concurrent(
4343
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
4444
list(nodes.items()),
4545
desc="Inserting entities into storage",
4646
)
4747

48-
run_concurrent(
48+
edges = run_concurrent(
4949
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
5050
list(edges.items()),
5151
desc="Inserting relationships into storage",

graphgen/operators/quiz/quiz_service.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections.abc import Iterable
2-
31
import pandas as pd
42

53
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator
@@ -28,7 +26,7 @@ def __init__(
2826
)
2927
self.generator = QuizGenerator(self.llm_client)
3028

31-
def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
29+
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
3230
data = batch.to_dict(orient="records")
3331
self.graph_storage.reload()
3432
return self.quiz(data)
@@ -62,22 +60,22 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
6260
logger.error("Error when quizzing description %s: %s", item, e)
6361
return None
6462

65-
def quiz(self, batch) -> Iterable[pd.DataFrame]:
63+
def quiz(self, batch) -> pd.DataFrame:
6664
"""
6765
Get all nodes and edges and quiz their descriptions using QuizGenerator.
6866
"""
6967
items = []
7068

7169
for item in batch:
72-
nodes = item.get("nodes", [])
73-
edges = item.get("edges", [])
70+
node_data = item.get("node", [])
71+
edge_data = item.get("edge", [])
7472

75-
for node_id, node_data in nodes.items():
76-
node_data = node_data[0]
73+
if node_data:
74+
node_id = node_data["entity_name"]
7775
desc = node_data["description"]
7876
items.append((node_id, desc))
79-
for edge_key, edge_data in edges.items():
80-
edge_data = edge_data[0]
77+
if edge_data:
78+
edge_key = (edge_data["src_id"], edge_data["tgt_id"])
8179
desc = edge_data["description"]
8280
items.append((edge_key, desc))
8381

0 commit comments

Comments
 (0)