Skip to content

Commit 3688560

Browse files
perf: move pre-tokenize from partition to build_kg to reduce memory (#156)
* perf: move pre-tokenize from partition to build_kg to reduce memory
* Potential fix for pull request finding 'Commented-out code'
* fix: use '' for an unknown node's description

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
1 parent d1b3b6e commit 3688560

6 files changed

Lines changed: 23 additions & 55 deletions

File tree

graphgen/bases/base_storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_node(self, node_id: str) -> Union[dict, None]:
7979
raise NotImplementedError
8080

8181
@abstractmethod
82-
def update_node(self, node_id: str, node_data: dict[str, str]):
82+
def update_node(self, node_id: str, node_data: dict[str, any]):
8383
raise NotImplementedError
8484

8585
@abstractmethod
@@ -96,7 +96,7 @@ def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None
9696

9797
@abstractmethod
9898
def update_edge(
99-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
99+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
100100
):
101101
raise NotImplementedError
102102

@@ -113,12 +113,12 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], No
113113
raise NotImplementedError
114114

115115
@abstractmethod
116-
def upsert_node(self, node_id: str, node_data: dict[str, str]):
116+
def upsert_node(self, node_id: str, node_data: dict[str, any]):
117117
raise NotImplementedError
118118

119119
@abstractmethod
120120
def upsert_edge(
121-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
121+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
122122
):
123123
raise NotImplementedError
124124

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
1818
def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
1919
super().__init__(llm_client)
2020
self.max_loop = max_loop
21+
self.tokenizer = llm_client.tokenizer
2122

2223
async def extract(
2324
self, chunk: Chunk
@@ -134,6 +135,7 @@ async def merge_nodes(
134135
"entity_name": entity_name,
135136
"description": description,
136137
"source_id": source_id,
138+
"length": self.tokenizer.count_tokens(description),
137139
}
138140
kg_instance.upsert_node(entity_name, node_data=node_data)
139141
return node_data
@@ -167,9 +169,11 @@ async def merge_edges(
167169
kg_instance.upsert_node(
168170
insert_id,
169171
node_data={
170-
"source_id": source_id,
171-
"description": description,
172172
"entity_type": "UNKNOWN",
173+
"entity_name": insert_id,
174+
"description": "",
175+
"source_id": source_id,
176+
"length": self.tokenizer.count_tokens(description),
173177
},
174178
)
175179

@@ -182,12 +186,13 @@ async def merge_edges(
182186
"tgt_id": tgt_id,
183187
"description": description,
184188
"source_id": source_id, # for traceability
189+
"length": self.tokenizer.count_tokens(description),
185190
}
186191

187192
kg_instance.upsert_edge(
188193
src_id,
189194
tgt_id,
190-
edge_data={"source_id": source_id, "description": description},
195+
edge_data=edge_data,
191196
)
192197
return edge_data
193198

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _add_unit(u):
9999
return False
100100
community_edges[i] = d
101101
used_e.add(i)
102-
token_sum += d.get("length", 0)
102+
token_sum += int(d.get("length", 0))
103103
return True
104104

105105
_add_unit(seed_unit)

graphgen/models/storage/graph/kuzu_storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def get_node(self, node_id: str) -> Any:
215215
data_str = result.get_next()[0]
216216
return self._safe_json_loads(data_str)
217217

218-
def update_node(self, node_id: str, node_data: dict[str, str]):
218+
def update_node(self, node_id: str, node_data: dict[str, any]):
219219
current_data = self.get_node(node_id)
220220
if current_data is None:
221221
print(f"Node {node_id} not found for update.")
@@ -263,7 +263,7 @@ def get_edge(self, source_node_id: str, target_node_id: str):
263263
return self._safe_json_loads(data_str)
264264

265265
def update_edge(
266-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
266+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
267267
):
268268
current_data = self.get_edge(source_node_id, target_node_id)
269269
if current_data is None:
@@ -318,7 +318,7 @@ def get_node_edges(self, source_node_id: str) -> Any:
318318
edges.append((src, dst, data))
319319
return edges
320320

321-
def upsert_node(self, node_id: str, node_data: dict[str, str]):
321+
def upsert_node(self, node_id: str, node_data: dict[str, any]):
322322
"""
323323
Insert or Update node.
324324
Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
@@ -336,7 +336,7 @@ def upsert_node(self, node_id: str, node_data: dict[str, str]):
336336
self._conn.execute(query, {"id": node_id, "data": json_data})
337337

338338
def upsert_edge(
339-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
339+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
340340
):
341341
"""
342342
Insert or Update edge.

graphgen/models/storage/graph/networkx_storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,22 +144,22 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], No
144144
def get_graph(self) -> nx.Graph:
145145
return self._graph
146146

147-
def upsert_node(self, node_id: str, node_data: dict[str, str]):
147+
def upsert_node(self, node_id: str, node_data: dict[str, any]):
148148
self._graph.add_node(node_id, **node_data)
149149

150-
def update_node(self, node_id: str, node_data: dict[str, str]):
150+
def update_node(self, node_id: str, node_data: dict[str, any]):
151151
if self._graph.has_node(node_id):
152152
self._graph.nodes[node_id].update(node_data)
153153
else:
154154
print(f"Node {node_id} not found in the graph for update.")
155155

156156
def upsert_edge(
157-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
157+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
158158
):
159159
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
160160

161161
def update_edge(
162-
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
162+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
163163
):
164164
if self._graph.has_edge(source_node_id, target_node_id):
165165
self._graph.edges[(source_node_id, target_node_id)].update(edge_data)

graphgen/operators/partition/partition_service.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,8 @@ def partition(self) -> Iterable[pd.DataFrame]:
6060
partitioner = DFSPartitioner()
6161
elif method == "ece":
6262
logger.info("Partitioning knowledge graph using ECE method.")
63-
# TODO: before ECE partitioning, we need to:
64-
# 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
65-
# 2. pre-tokenize nodes and edges to get the token length
66-
self._pre_tokenize()
63+
# before ECE partitioning, we need to:
64+
# 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
6765
partitioner = ECEPartitioner()
6866
elif method == "leiden":
6967
logger.info("Partitioning knowledge graph using Leiden method.")
@@ -97,41 +95,6 @@ def partition(self) -> Iterable[pd.DataFrame]:
9795
)
9896
logger.info("Total communities partitioned: %d", count)
9997

100-
def _pre_tokenize(self) -> None:
101-
"""Pre-tokenize all nodes and edges to add token length information."""
102-
logger.info("Starting pre-tokenization of nodes and edges...")
103-
104-
nodes = self.kg_instance.get_all_nodes()
105-
edges = self.kg_instance.get_all_edges()
106-
107-
# Process nodes
108-
for node_id, node_data in nodes:
109-
if "length" not in node_data:
110-
try:
111-
description = node_data.get("description", "")
112-
tokens = self.tokenizer_instance.encode(description)
113-
node_data["length"] = len(tokens)
114-
self.kg_instance.update_node(node_id, node_data)
115-
except Exception as e:
116-
logger.warning("Failed to tokenize node %s: %s", node_id, e)
117-
node_data["length"] = 0
118-
119-
# Process edges
120-
for u, v, edge_data in edges:
121-
if "length" not in edge_data:
122-
try:
123-
description = edge_data.get("description", "")
124-
tokens = self.tokenizer_instance.encode(description)
125-
edge_data["length"] = len(tokens)
126-
self.kg_instance.update_edge(u, v, edge_data)
127-
except Exception as e:
128-
logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e)
129-
edge_data["length"] = 0
130-
131-
# Persist changes
132-
self.kg_instance.index_done_callback()
133-
logger.info("Pre-tokenization completed.")
134-
13598
def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
13699
"""
137100
Attach additional data from chunk_storage to nodes in the batch.

0 commit comments

Comments (0)