Skip to content

Commit 952a2d4

Browse files
perf: reduce system bubble time (#154)
* perf: reduce system bubble time
* feat: make build_kg return list of nodes & edges
* chore: upgrade ray to 2.53.0
1 parent 0c91d7f commit 952a2d4

10 files changed

Lines changed: 92 additions & 71 deletions

File tree

examples/generate/generate_aggregated_qa/aggregated_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@ nodes:
3434

3535
- id: quiz
3636
op_name: quiz
37-
type: aggregate
37+
type: map_batch
3838
dependencies:
3939
- build_kg
4040
execution_params:
4141
replicas: 1
4242
batch_size: 128
4343
params:
4444
quiz_samples: 2 # number of quiz samples to generate
45-
concurrency_limit: 200
4645

4746
- id: judge
4847
op_name: judge

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def merge_nodes(
9999
self,
100100
node_data: tuple[str, List[dict]],
101101
kg_instance: BaseGraphStorage,
102-
) -> None:
102+
) -> dict:
103103
entity_name, node_data = node_data
104104
entity_types = []
105105
source_ids = []
@@ -131,16 +131,18 @@ async def merge_nodes(
131131

132132
node_data = {
133133
"entity_type": entity_type,
134+
"entity_name": entity_name,
134135
"description": description,
135136
"source_id": source_id,
136137
}
137138
kg_instance.upsert_node(entity_name, node_data=node_data)
139+
return node_data
138140

139141
async def merge_edges(
140142
self,
141143
edges_data: tuple[Tuple[str, str], List[dict]],
142144
kg_instance: BaseGraphStorage,
143-
) -> None:
145+
) -> dict:
144146
(src_id, tgt_id), edge_data = edges_data
145147

146148
source_ids = []
@@ -175,11 +177,19 @@ async def merge_edges(
175177
f"({src_id}, {tgt_id})", description
176178
)
177179

180+
edge_data = {
181+
"src_id": src_id,
182+
"tgt_id": tgt_id,
183+
"description": description,
184+
"source_id": source_id, # for traceability
185+
}
186+
178187
kg_instance.upsert_edge(
179188
src_id,
180189
tgt_id,
181190
edge_data={"source_id": source_id, "description": description},
182191
)
192+
return edge_data
183193

184194
async def _handle_kg_summary(
185195
self,

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import random
23
from collections import deque
34
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -34,17 +35,18 @@ def _sort_units(units: list, edge_sampling: str) -> list:
3435
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
3536
:return: sorted units
3637
"""
38+
default_loss = -math.log(0.1)
3739
if edge_sampling == "random":
3840
random.shuffle(units)
3941
elif edge_sampling == "min_loss":
4042
units = sorted(
4143
units,
42-
key=lambda x: x[-1]["loss"],
44+
key=lambda x: x[-1].get("loss", default_loss),
4345
)
4446
elif edge_sampling == "max_loss":
4547
units = sorted(
4648
units,
47-
key=lambda x: x[-1]["loss"],
49+
key=lambda x: x[-1].get("loss", default_loss),
4850
reverse=True,
4951
)
5052
else:
@@ -142,7 +144,7 @@ def _add_unit(u):
142144
return Community(
143145
id=seed_unit[1],
144146
nodes=list(community_nodes.keys()),
145-
edges=[tuple(sorted(e)) for e in community_edges]
147+
edges=[tuple(sorted(e)) for e in community_edges],
146148
)
147149

148150
for unit in tqdm(all_units, desc="ECE partition"):

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
2828
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
2929

3030
# consume the chunks and build kg
31-
self.build_kg(docs)
32-
return pd.DataFrame([{"status": "kg_building_completed"}])
31+
nodes, edges = self.build_kg(docs)
32+
return pd.DataFrame(
33+
[{"node": node, "edge": []} for node in nodes]
34+
+ [{"node": [], "edge": edge} for edge in edges]
35+
)
3336

34-
def build_kg(self, chunks: List[Chunk]) -> None:
37+
def build_kg(self, chunks: List[Chunk]) -> tuple:
3538
"""
3639
Build knowledge graph (KG) and merge into kg_instance
3740
"""
@@ -42,24 +45,34 @@ def build_kg(self, chunks: List[Chunk]) -> None:
4245
if chunk.type in ("image", "video", "table", "formula")
4346
]
4447

48+
nodes = []
49+
edges = []
50+
4551
if len(text_chunks) == 0:
4652
logger.info("All text chunks are already in the storage")
4753
else:
4854
logger.info("[Text Entity and Relation Extraction] processing ...")
49-
build_text_kg(
55+
text_nodes, text_edges = build_text_kg(
5056
llm_client=self.llm_client,
5157
kg_instance=self.graph_storage,
5258
chunks=text_chunks,
5359
max_loop=self.max_loop,
5460
)
61+
nodes += text_nodes
62+
edges += text_edges
5563
if len(mm_chunks) == 0:
5664
logger.info("All multi-modal chunks are already in the storage")
5765
else:
5866
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
59-
build_mm_kg(
67+
mm_nodes, mm_edges = build_mm_kg(
6068
llm_client=self.llm_client,
6169
kg_instance=self.graph_storage,
6270
chunks=mm_chunks,
6371
)
72+
nodes += mm_nodes
73+
edges += mm_edges
6474

6575
self.graph_storage.index_done_callback()
76+
logger.info("Knowledge graph building completed.")
77+
78+
return nodes, edges

graphgen/operators/build_kg/build_mm_kg.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def build_mm_kg(
1212
llm_client: BaseLLMWrapper,
1313
kg_instance: BaseGraphStorage,
1414
chunks: List[Chunk],
15-
):
15+
) -> tuple:
1616
"""
1717
Build multi-modal KG and merge into kg_instance
1818
:param llm_client: Synthesizer LLM model to extract entities and relationships
@@ -37,14 +37,16 @@ def build_mm_kg(
3737
for k, v in e.items():
3838
edges[tuple(sorted(k))].extend(v)
3939

40-
run_concurrent(
40+
nodes = run_concurrent(
4141
lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
4242
list(nodes.items()),
4343
desc="Inserting entities into storage",
4444
)
4545

46-
run_concurrent(
46+
edges = run_concurrent(
4747
lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
4848
list(edges.items()),
4949
desc="Inserting relationships into storage",
5050
)
51+
52+
return nodes, edges

graphgen/operators/build_kg/build_text_kg.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def build_text_kg(
1313
kg_instance: BaseGraphStorage,
1414
chunks: List[Chunk],
1515
max_loop: int = 3,
16-
):
16+
) -> tuple:
1717
"""
1818
:param llm_client: Synthesizer LLM model to extract entities and relationships
1919
:param kg_instance
@@ -39,14 +39,16 @@ def build_text_kg(
3939
for k, v in e.items():
4040
edges[tuple(sorted(k))].extend(v)
4141

42-
run_concurrent(
42+
nodes = run_concurrent(
4343
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
4444
list(nodes.items()),
4545
desc="Inserting entities into storage",
4646
)
4747

48-
run_concurrent(
48+
edges = run_concurrent(
4949
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
5050
list(edges.items()),
5151
desc="Inserting relationships into storage",
5252
)
53+
54+
return nodes, edges

graphgen/operators/evaluate/evaluate_service.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
9595
answer=str(item.get("answer", "")),
9696
)
9797
if not qa_pair.question or not qa_pair.answer:
98-
self.logger.error("Empty question or answer, skipping.")
98+
logger.error("Empty question or answer, skipping.")
9999
return {}
100100
except Exception as e:
101-
self.logger.error("Error in QAPair creation: %s", str(e))
101+
logger.error("Error in QAPair creation: %s", str(e))
102102
return {}
103103

104104
for metric, evaluator in self.qa_evaluators.items():
@@ -110,7 +110,7 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
110110
else:
111111
item[metric] = float(score)
112112
except Exception as e:
113-
self.logger.error("Error in %s evaluation: %s", metric, str(e))
113+
logger.error("Error in %s evaluation: %s", metric, str(e))
114114
item[metric] = None
115115
return item
116116

@@ -136,7 +136,7 @@ def transform_messages_format(items: list[dict]) -> list[dict]:
136136
return []
137137

138138
if not self.qa_evaluators:
139-
self.logger.warning("No QA evaluators initialized, skipping QA evaluation")
139+
logger.warning("No QA evaluators initialized, skipping QA evaluation")
140140
return []
141141

142142
items = transform_messages_format(items)
@@ -155,11 +155,11 @@ def _evaluate_kg(self) -> Dict[str, Any]:
155155

156156
for metric, evaluator in self.kg_evaluators.items():
157157
try:
158-
self.logger.info("Running %s evaluation...", metric)
158+
logger.info("Running %s evaluation...", metric)
159159
score = evaluator.evaluate()
160160
results[metric] = score
161161
except Exception as e:
162-
self.logger.error("Error in %s evaluation: %s", metric, str(e))
162+
logger.error("Error in %s evaluation: %s", metric, str(e))
163163
results[metric] = {"error": str(e)}
164164
return results
165165

graphgen/operators/partition/partition_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,13 @@ def partition(self) -> Iterable[pd.DataFrame]:
7979
else:
8080
raise ValueError(f"Unsupported partition method: {method}")
8181

82-
communities = partitioner.partition(g=self.kg_instance, **method_params)
82+
communities: Iterable = partitioner.partition(
83+
g=self.kg_instance, **method_params
84+
)
8385

86+
count = 0
8487
for community in communities:
88+
count += 1
8589
batch = partitioner.community2batch(community, g=self.kg_instance)
8690
batch = self._attach_additional_data_to_node(batch)
8791

@@ -91,6 +95,7 @@ def partition(self) -> Iterable[pd.DataFrame]:
9195
"edges": [batch[1]],
9296
}
9397
)
98+
logger.info("Total communities partitioned: %d", count)
9499

95100
def _pre_tokenize(self) -> None:
96101
"""Pre-tokenize all nodes and edges to add token length information."""
Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections.abc import Iterable
2-
31
import pandas as pd
42

53
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator
@@ -15,7 +13,6 @@ def __init__(
1513
graph_backend: str = "kuzu",
1614
kv_backend: str = "rocksdb",
1715
quiz_samples: int = 1,
18-
concurrency_limit: int = 200,
1916
):
2017
super().__init__(working_dir=working_dir, op_name="quiz_service")
2118
self.quiz_samples = quiz_samples
@@ -28,21 +25,16 @@ def __init__(
2825
backend=kv_backend, working_dir=working_dir, namespace="quiz"
2926
)
3027
self.generator = QuizGenerator(self.llm_client)
31-
self.concurrency_limit = concurrency_limit
3228

33-
def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
34-
# this operator does not consume any batch data
35-
# but for compatibility we keep the interface
36-
_ = batch.to_dict(orient="records")
29+
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
30+
data = batch.to_dict(orient="records")
3731
self.graph_storage.reload()
38-
yield from self.quiz()
32+
return self.quiz(data)
3933

4034
async def _process_single_quiz(self, item: tuple) -> dict | None:
4135
# if quiz in quiz_storage exists already, directly get it
4236
index, desc = item
4337
_quiz_id = compute_dict_hash({"index": index, "description": desc})
44-
if self.quiz_storage.get_by_id(_quiz_id):
45-
return None
4638

4739
tasks = []
4840
for i in range(self.quiz_samples):
@@ -68,47 +60,43 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
6860
logger.error("Error when quizzing description %s: %s", item, e)
6961
return None
7062

71-
def quiz(self) -> Iterable[pd.DataFrame]:
63+
def quiz(self, batch) -> pd.DataFrame:
7264
"""
7365
Get all nodes and edges and quiz their descriptions using QuizGenerator.
7466
"""
75-
edges = self.graph_storage.get_all_edges()
76-
nodes = self.graph_storage.get_all_nodes()
77-
7867
items = []
7968

80-
for edge in edges:
81-
edge_data = edge[2]
82-
desc = edge_data["description"]
83-
items.append(((edge[0], edge[1]), desc))
69+
for item in batch:
70+
node_data = item.get("node", [])
71+
edge_data = item.get("edge", [])
8472

85-
for node in nodes:
86-
node_data = node[1]
87-
desc = node_data["description"]
88-
items.append((node[0], desc))
73+
if node_data:
74+
node_id = node_data["entity_name"]
75+
desc = node_data["description"]
76+
items.append((node_id, desc))
77+
if edge_data:
78+
edge_key = (edge_data["src_id"], edge_data["tgt_id"])
79+
desc = edge_data["description"]
80+
items.append((edge_key, desc))
8981

9082
logger.info("Total descriptions to quiz: %d", len(items))
9183

92-
for i in range(0, len(items), self.concurrency_limit):
93-
batch_items = items[i : i + self.concurrency_limit]
94-
batch_results = run_concurrent(
95-
self._process_single_quiz,
96-
batch_items,
97-
desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})",
98-
unit="description",
99-
)
84+
results = run_concurrent(
85+
self._process_single_quiz,
86+
items,
87+
desc=f"Quizzing batch of {len(items)} descriptions",
88+
unit="description",
89+
)
90+
valid_results = [res for res in results if res]
10091

101-
final_results = []
102-
for new_result in batch_results:
103-
if new_result:
104-
self.quiz_storage.upsert(
105-
{
106-
new_result["_quiz_id"]: {
107-
"description": new_result["description"],
108-
"quizzes": new_result["quizzes"],
109-
}
110-
}
111-
)
112-
final_results.append(new_result)
113-
self.quiz_storage.index_done_callback()
114-
yield pd.DataFrame(final_results)
92+
for res in valid_results:
93+
self.quiz_storage.upsert(
94+
{
95+
res["_quiz_id"]: {
96+
"description": res["description"],
97+
"quizzes": res["quizzes"],
98+
}
99+
}
100+
)
101+
self.quiz_storage.index_done_callback()
102+
return pd.DataFrame(valid_results)

0 commit comments

Comments (0)