Skip to content

Commit 931c4cb

Browse files
perf: reduce system bubble time
1 parent 0c91d7f commit 931c4cb

8 files changed

Lines changed: 71 additions & 61 deletions

File tree

examples/generate/generate_aggregated_qa/aggregated_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,14 @@ nodes:
3434

3535
- id: quiz
3636
op_name: quiz
37-
type: aggregate
37+
type: map_batch
3838
dependencies:
3939
- build_kg
4040
execution_params:
4141
replicas: 1
4242
batch_size: 128
4343
params:
4444
quiz_samples: 2 # number of quiz samples to generate
45-
concurrency_limit: 200
4645

4746
- id: judge
4847
op_name: judge

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import random
23
from collections import deque
34
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -34,17 +35,18 @@ def _sort_units(units: list, edge_sampling: str) -> list:
3435
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
3536
:return: sorted units
3637
"""
38+
default_loss = -math.log(0.1)
3739
if edge_sampling == "random":
3840
random.shuffle(units)
3941
elif edge_sampling == "min_loss":
4042
units = sorted(
4143
units,
42-
key=lambda x: x[-1]["loss"],
44+
key=lambda x: x[-1].get("loss", default_loss),
4345
)
4446
elif edge_sampling == "max_loss":
4547
units = sorted(
4648
units,
47-
key=lambda x: x[-1]["loss"],
49+
key=lambda x: x[-1].get("loss", default_loss),
4850
reverse=True,
4951
)
5052
else:
@@ -142,7 +144,7 @@ def _add_unit(u):
142144
return Community(
143145
id=seed_unit[1],
144146
nodes=list(community_nodes.keys()),
145-
edges=[tuple(sorted(e)) for e in community_edges]
147+
edges=[tuple(sorted(e)) for e in community_edges],
146148
)
147149

148150
for unit in tqdm(all_units, desc="ECE partition"):

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
2828
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
2929

3030
# consume the chunks and build kg
31-
self.build_kg(docs)
32-
return pd.DataFrame([{"status": "kg_building_completed"}])
31+
nodes, edges = self.build_kg(docs)
32+
return pd.DataFrame([{"nodes": nodes, "edges": edges}])
3333

34-
def build_kg(self, chunks: List[Chunk]) -> None:
34+
def build_kg(self, chunks: List[Chunk]) -> tuple:
3535
"""
3636
Build the knowledge graph (KG), merge it into kg_instance, and return the resulting (nodes, edges)
3737
"""
@@ -42,24 +42,34 @@ def build_kg(self, chunks: List[Chunk]) -> None:
4242
if chunk.type in ("image", "video", "table", "formula")
4343
]
4444

45+
nodes = {}
46+
edges = {}
47+
4548
if len(text_chunks) == 0:
4649
logger.info("All text chunks are already in the storage")
4750
else:
4851
logger.info("[Text Entity and Relation Extraction] processing ...")
49-
build_text_kg(
52+
text_nodes, text_edges = build_text_kg(
5053
llm_client=self.llm_client,
5154
kg_instance=self.graph_storage,
5255
chunks=text_chunks,
5356
max_loop=self.max_loop,
5457
)
58+
nodes.update(text_nodes)
59+
edges.update(text_edges)
5560
if len(mm_chunks) == 0:
5661
logger.info("All multi-modal chunks are already in the storage")
5762
else:
5863
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
59-
build_mm_kg(
64+
mm_nodes, mm_edges = build_mm_kg(
6065
llm_client=self.llm_client,
6166
kg_instance=self.graph_storage,
6267
chunks=mm_chunks,
6368
)
69+
nodes.update(mm_nodes)
70+
edges.update(mm_edges)
6471

6572
self.graph_storage.index_done_callback()
73+
logger.info("Knowledge graph building completed.")
74+
75+
return nodes, edges

graphgen/operators/build_kg/build_mm_kg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def build_mm_kg(
1212
llm_client: BaseLLMWrapper,
1313
kg_instance: BaseGraphStorage,
1414
chunks: List[Chunk],
15-
):
15+
) -> tuple:
1616
"""
1717
Build the multi-modal KG, merge it into kg_instance, and return the resulting (nodes, edges)
1818
:param llm_client: Synthesizer LLM model to extract entities and relationships
@@ -48,3 +48,5 @@ def build_mm_kg(
4848
list(edges.items()),
4949
desc="Inserting relationships into storage",
5050
)
51+
52+
return nodes, edges

graphgen/operators/build_kg/build_text_kg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def build_text_kg(
1313
kg_instance: BaseGraphStorage,
1414
chunks: List[Chunk],
1515
max_loop: int = 3,
16-
):
16+
) -> tuple:
1717
"""
1818
:param llm_client: Synthesizer LLM model to extract entities and relationships
1919
:param kg_instance
@@ -50,3 +50,5 @@ def build_text_kg(
5050
list(edges.items()),
5151
desc="Inserting relationships into storage",
5252
)
53+
54+
return nodes, edges

graphgen/operators/evaluate/evaluate_service.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
9595
answer=str(item.get("answer", "")),
9696
)
9797
if not qa_pair.question or not qa_pair.answer:
98-
self.logger.error("Empty question or answer, skipping.")
98+
logger.error("Empty question or answer, skipping.")
9999
return {}
100100
except Exception as e:
101-
self.logger.error("Error in QAPair creation: %s", str(e))
101+
logger.error("Error in QAPair creation: %s", str(e))
102102
return {}
103103

104104
for metric, evaluator in self.qa_evaluators.items():
@@ -110,7 +110,7 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
110110
else:
111111
item[metric] = float(score)
112112
except Exception as e:
113-
self.logger.error("Error in %s evaluation: %s", metric, str(e))
113+
logger.error("Error in %s evaluation: %s", metric, str(e))
114114
item[metric] = None
115115
return item
116116

@@ -136,7 +136,7 @@ def transform_messages_format(items: list[dict]) -> list[dict]:
136136
return []
137137

138138
if not self.qa_evaluators:
139-
self.logger.warning("No QA evaluators initialized, skipping QA evaluation")
139+
logger.warning("No QA evaluators initialized, skipping QA evaluation")
140140
return []
141141

142142
items = transform_messages_format(items)
@@ -155,11 +155,11 @@ def _evaluate_kg(self) -> Dict[str, Any]:
155155

156156
for metric, evaluator in self.kg_evaluators.items():
157157
try:
158-
self.logger.info("Running %s evaluation...", metric)
158+
logger.info("Running %s evaluation...", metric)
159159
score = evaluator.evaluate()
160160
results[metric] = score
161161
except Exception as e:
162-
self.logger.error("Error in %s evaluation: %s", metric, str(e))
162+
logger.error("Error in %s evaluation: %s", metric, str(e))
163163
results[metric] = {"error": str(e)}
164164
return results
165165

graphgen/operators/partition/partition_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,13 @@ def partition(self) -> Iterable[pd.DataFrame]:
7979
else:
8080
raise ValueError(f"Unsupported partition method: {method}")
8181

82-
communities = partitioner.partition(g=self.kg_instance, **method_params)
82+
communities: Iterable = partitioner.partition(
83+
g=self.kg_instance, **method_params
84+
)
8385

86+
count = 0
8487
for community in communities:
88+
count += 1
8589
batch = partitioner.community2batch(community, g=self.kg_instance)
8690
batch = self._attach_additional_data_to_node(batch)
8791

@@ -91,6 +95,7 @@ def partition(self) -> Iterable[pd.DataFrame]:
9195
"edges": [batch[1]],
9296
}
9397
)
98+
logger.info("Total communities partitioned: %d", count)
9499

95100
def _pre_tokenize(self) -> None:
96101
"""Pre-tokenize all nodes and edges to add token length information."""

graphgen/operators/quiz/quiz_service.py

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def __init__(
1515
graph_backend: str = "kuzu",
1616
kv_backend: str = "rocksdb",
1717
quiz_samples: int = 1,
18-
concurrency_limit: int = 200,
1918
):
2019
super().__init__(working_dir=working_dir, op_name="quiz_service")
2120
self.quiz_samples = quiz_samples
@@ -28,21 +27,16 @@ def __init__(
2827
backend=kv_backend, working_dir=working_dir, namespace="quiz"
2928
)
3029
self.generator = QuizGenerator(self.llm_client)
31-
self.concurrency_limit = concurrency_limit
3230

3331
def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
34-
# this operator does not consume any batch data
35-
# but for compatibility we keep the interface
36-
_ = batch.to_dict(orient="records")
32+
data = batch.to_dict(orient="records")
3733
self.graph_storage.reload()
38-
yield from self.quiz()
34+
return self.quiz(data)
3935

4036
async def _process_single_quiz(self, item: tuple) -> dict | None:
4137
# if quiz in quiz_storage exists already, directly get it
4238
index, desc = item
4339
_quiz_id = compute_dict_hash({"index": index, "description": desc})
44-
if self.quiz_storage.get_by_id(_quiz_id):
45-
return None
4640

4741
tasks = []
4842
for i in range(self.quiz_samples):
@@ -68,47 +62,43 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
6862
logger.error("Error when quizzing description %s: %s", item, e)
6963
return None
7064

71-
def quiz(self) -> Iterable[pd.DataFrame]:
65+
def quiz(self, batch) -> Iterable[pd.DataFrame]:
7266
"""
7367
Collect the nodes and edges carried by the given batch and quiz their descriptions using QuizGenerator.
7468
"""
75-
edges = self.graph_storage.get_all_edges()
76-
nodes = self.graph_storage.get_all_nodes()
77-
7869
items = []
7970

80-
for edge in edges:
81-
edge_data = edge[2]
82-
desc = edge_data["description"]
83-
items.append(((edge[0], edge[1]), desc))
71+
for item in batch:
72+
nodes = item.get("nodes", [])
73+
edges = item.get("edges", [])
8474

85-
for node in nodes:
86-
node_data = node[1]
87-
desc = node_data["description"]
88-
items.append((node[0], desc))
75+
for node_id, node_data in nodes.items():
76+
node_data = node_data[0]
77+
desc = node_data["description"]
78+
items.append((node_id, desc))
79+
for edge_key, edge_data in edges.items():
80+
edge_data = edge_data[0]
81+
desc = edge_data["description"]
82+
items.append((edge_key, desc))
8983

9084
logger.info("Total descriptions to quiz: %d", len(items))
9185

92-
for i in range(0, len(items), self.concurrency_limit):
93-
batch_items = items[i : i + self.concurrency_limit]
94-
batch_results = run_concurrent(
95-
self._process_single_quiz,
96-
batch_items,
97-
desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})",
98-
unit="description",
99-
)
86+
results = run_concurrent(
87+
self._process_single_quiz,
88+
items,
89+
desc=f"Quizzing batch of {len(items)} descriptions",
90+
unit="description",
91+
)
92+
valid_results = [res for res in results if res]
10093

101-
final_results = []
102-
for new_result in batch_results:
103-
if new_result:
104-
self.quiz_storage.upsert(
105-
{
106-
new_result["_quiz_id"]: {
107-
"description": new_result["description"],
108-
"quizzes": new_result["quizzes"],
109-
}
110-
}
111-
)
112-
final_results.append(new_result)
113-
self.quiz_storage.index_done_callback()
114-
yield pd.DataFrame(final_results)
94+
for res in valid_results:
95+
self.quiz_storage.upsert(
96+
{
97+
res["_quiz_id"]: {
98+
"description": res["description"],
99+
"quizzes": res["quizzes"],
100+
}
101+
}
102+
)
103+
self.quiz_storage.index_done_callback()
104+
return pd.DataFrame(valid_results)

0 commit comments

Comments
 (0)