Skip to content

Commit af530de

Browse files
wip: move storage logic to baseOperator
1 parent b735443 commit af530de

8 files changed

Lines changed: 64 additions & 71 deletions

File tree

graphgen/bases/base_operator.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import os
33
from abc import ABC, abstractmethod
4-
from typing import Iterable, Union
4+
from typing import Iterable, Union, Tuple
55

66
import pandas as pd
77
import ray
@@ -64,12 +64,17 @@ def __call__(
6464
if to_process.empty:
6565
return
6666

67-
docs = to_process.to_dict(orient="records")
68-
result = self.process(docs)
67+
data = to_process.to_dict(orient="records")
68+
result, meta_update = self.process(data)
6969
if inspect.isgenerator(result):
70-
yield from result
70+
is_first = True
71+
for res in result:
72+
yield pd.DataFrame([res])
73+
self.store([res], meta_update if is_first else {})
74+
is_first = False
7175
else:
72-
yield result
76+
yield pd.DataFrame(result)
77+
self.store(result, meta_update)
7378
finally:
7479
CURRENT_LOGGER_VAR.reset(logger_token)
7580

@@ -130,5 +135,11 @@ def store(self, results: list, meta_update: dict):
130135
self.kv_storage.index_done_callback()
131136

132137
@abstractmethod
133-
def process(self, batch: list) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
134-
pass
138+
def process(self, batch: list) -> Tuple[Union[list, Iterable[list]], dict]:
139+
"""
140+
Process the input batch and return the result.
141+
:param batch
142+
:return:
143+
result: DataFrame of processed documents
144+
meta_update: dict of meta data to be updated
145+
"""

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import pandas as pd
1+
from typing import Tuple
22

33
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
44
from graphgen.bases.datatypes import Chunk
@@ -27,7 +27,7 @@ def __init__(
2727
self.build_kwargs = build_kwargs
2828
self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
2929

30-
def process(self, batch: list) -> pd.DataFrame:
30+
def process(self, batch: list) -> Tuple[list, dict]:
3131
"""
3232
Build knowledge graph (KG) and merge into kg_instance
3333
"""
@@ -99,8 +99,4 @@ def process(self, batch: list) -> pd.DataFrame:
9999
source_ids = edge.get("source_id", "").split("<SEP>")
100100
for source_id in source_ids:
101101
meta_updates.setdefault(source_id, []).append(str(trace_id))
102-
self.store(
103-
results,
104-
meta_updates,
105-
)
106-
return pd.DataFrame(results)
102+
return results, meta_updates

graphgen/operators/chunk/chunk_service.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import os
22
from functools import lru_cache
3-
from typing import Union
4-
5-
import pandas as pd
3+
from typing import Union, Tuple
64

75
from graphgen.bases import BaseOperator
86
from graphgen.models import (
@@ -51,7 +49,7 @@ def __init__(
5149
self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
5250
self.chunk_kwargs = chunk_kwargs
5351

54-
def process(self, batch: list) -> pd.DataFrame:
52+
def process(self, batch: list) -> Tuple[list, dict]:
5553
results = []
5654
meta_updates = {}
5755
for doc in batch:
@@ -72,17 +70,15 @@ def process(self, batch: list) -> pd.DataFrame:
7270
else len(text_chunk),
7371
"language": doc_language,
7472
}
75-
chunk["_trace_id"] = self.generate_trace_id(chunk)
73+
chunk["_trace_id"] = self.get_trace_id(chunk)
7674
results.append(chunk)
7775
meta_updates.setdefault(doc["_trace_id"], []).append(
7876
chunk["_trace_id"]
7977
)
8078
else:
8179
# other types of documents(images, sequences) are not chunked
82-
doc["_trace_id"] = self.generate_trace_id(doc)
83-
results.append(doc)
84-
self.store(
85-
results,
86-
meta_updates,
87-
)
88-
return pd.DataFrame(results)
80+
data = doc.copy()
81+
data["_trace_id"] = self.get_trace_id(data)
82+
results.append(data)
83+
meta_updates.setdefault(doc["_trace_id"], []).append(data["_trace_id"])
84+
return results, meta_updates

graphgen/operators/extract/extract_service.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
2-
3-
import pandas as pd
2+
from typing import Tuple
43

54
from graphgen.bases import BaseLLMWrapper, BaseOperator, Chunk
65
from graphgen.common import init_llm
@@ -26,7 +25,7 @@ def __init__(
2625
else:
2726
raise ValueError(f"Unsupported extraction method: {self.method}")
2827

29-
def process(self, batch: list) -> pd.DataFrame:
28+
def process(self, batch: list) -> Tuple[list, dict]:
3029
logger.info("Start extracting information from %d items", len(batch))
3130
chunks = [Chunk.from_dict(item["_trace_id"], item) for item in batch]
3231
results = run_concurrent(

graphgen/operators/generate/generate_service.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import pandas as pd
2-
1+
from typing import Tuple
32
from graphgen.bases import BaseKVStorage, BaseLLMWrapper, BaseOperator
43
from graphgen.common import init_llm, init_storage
54
from graphgen.utils import logger, run_concurrent
@@ -80,11 +79,9 @@ def __init__(
8079
else:
8180
raise ValueError(f"Unsupported generation mode: {method}")
8281

83-
def process(self, batch: list[dict]) -> pd.DataFrame:
82+
def process(self, batch: list) -> Tuple[list, dict]:
8483
"""
8584
Generate question-answer pairs based on nodes and edges.
86-
:param batch
87-
:return: QA pairs
8885
"""
8986
logger.info("[Generation] mode: %s, batches: %d", self.method, len(batch))
9087
triples = [(item["nodes"], item["edges"]) for item in batch]
@@ -106,11 +103,7 @@ def process(self, batch: list[dict]) -> pd.DataFrame:
106103
res = self.generator.format_generation_results(
107104
qa_pair, output_data_format=self.data_format
108105
)
109-
res["_trace_id"] = self.generate_trace_id(res)
106+
res["_trace_id"] = self.get_trace_id(res)
110107
final_results.append(res)
111108
meta_updates.setdefault(input_trace_id, []).append(res["_trace_id"])
112-
self.store(
113-
final_results,
114-
meta_updates,
115-
)
116-
return pd.DataFrame(final_results)
109+
return final_results, meta_updates

graphgen/operators/judge/judge_service.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
from typing import Tuple
12
import math
23

3-
import pandas as pd
4-
54
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
65
from graphgen.common import init_llm, init_storage
76
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
@@ -44,7 +43,7 @@ async def _process_single_judge(self, item: dict) -> dict:
4443
item["loss"] = -math.log(0.1)
4544
return item
4645

47-
def process(self, batch: list) -> pd.DataFrame:
46+
def process(self, batch: list) -> Tuple[list, dict]:
4847
"""
4948
Judge the description in the item and compute the loss.
5049
"""
@@ -78,10 +77,9 @@ def process(self, batch: list) -> pd.DataFrame:
7877
edge_data["loss"] = loss
7978
self.graph_storage.update_edge(edge_source, edge_target, edge_data)
8079

81-
result["_trace_id"] = self.generate_trace_id(result)
80+
result["_trace_id"] = self.get_trace_id(result)
8281
to_store.append(result)
8382
meta_update.setdefault(input_trace_id, []).append(result["_trace_id"])
8483
self.graph_storage.index_done_callback()
85-
self.store(to_store, meta_update)
8684

87-
return pd.DataFrame(results)
85+
return results, meta_update

graphgen/operators/partition/partition_service.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
2-
from typing import Iterable
3-
4-
import pandas as pd
2+
from typing import Iterable, Tuple
53

64
from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer
75
from graphgen.common import init_storage
@@ -24,7 +22,9 @@ def __init__(
2422
graph_backend: str = "kuzu",
2523
**partition_kwargs,
2624
):
27-
super().__init__(working_dir=working_dir, op_name="partition")
25+
super().__init__(
26+
working_dir=working_dir, kv_backend=kv_backend, op_name="partition"
27+
)
2828
self.kg_instance: BaseGraphStorage = init_storage(
2929
backend=graph_backend,
3030
working_dir=working_dir,
@@ -55,7 +55,7 @@ def __init__(
5555
else:
5656
raise ValueError(f"Unsupported partition method: {method}")
5757

58-
def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
58+
def process(self, batch: list) -> Tuple[Iterable[list], dict]:
5959
# this operator does not consume any batch data
6060
# but for compatibility we keep the interface
6161
self.kg_instance.reload()
@@ -64,19 +64,22 @@ def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
6464
g=self.kg_instance, **self.method_params
6565
)
6666

67-
count = 0
68-
for community in communities:
69-
count += 1
70-
batch = self.partitioner.community2batch(community, g=self.kg_instance)
71-
# batch = self._attach_additional_data_to_node(batch)
67+
def generator():
68+
count = 0
69+
for community in communities:
70+
count += 1
71+
batch = self.partitioner.community2batch(community, g=self.kg_instance)
72+
# batch = self._attach_additional_data_to_node(batch)
73+
74+
result = {
75+
"nodes": batch[0],
76+
"edges": batch[1],
77+
}
78+
result["_trace_id"] = self.get_trace_id(result)
79+
yield result
80+
logger.info("Total communities partitioned: %d", count)
7281

73-
result = {
74-
"nodes": batch[0],
75-
"edges": batch[1],
76-
}
77-
result["_trace_id"] = self.generate_trace_id(result)
78-
yield pd.DataFrame([result])
79-
logger.info("Total communities partitioned: %d", count)
82+
return generator(), {}
8083

8184
# def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
8285
# """

graphgen/operators/quiz/quiz_service.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import pandas as pd
1+
from typing import Tuple
22

33
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
44
from graphgen.common import init_llm, init_storage
@@ -50,7 +50,7 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
5050
logger.error("Error when quizzing description %s: %s", item, e)
5151
return None
5252

53-
def process(self, batch: list) -> pd.DataFrame:
53+
def process(self, batch: list) -> Tuple[list, dict]:
5454
"""
5555
Get all nodes and edges and quiz their descriptions using QuizGenerator.
5656
"""
@@ -67,7 +67,7 @@ def process(self, batch: list) -> pd.DataFrame:
6767
edge_key = (edge["src_id"], edge["tgt_id"])
6868
items.append((input_id, edge["description"], edge_key))
6969
if not items:
70-
return pd.DataFrame()
70+
return [], {}
7171

7272
logger.info("Total descriptions to quiz: %d", len(items))
7373
results = run_concurrent(
@@ -83,11 +83,8 @@ def process(self, batch: list) -> pd.DataFrame:
8383
for (input_id, _, _), quiz_data in zip(items, results):
8484
if quiz_data is None:
8585
continue
86-
quiz_data["_trace_id"] = self.generate_trace_id(quiz_data)
86+
quiz_data["_trace_id"] = self.get_trace_id(quiz_data)
8787
final_results.append(quiz_data)
8888
meta_update[input_id] = [quiz_data["_trace_id"]]
8989

90-
if final_results:
91-
self.store(final_results, meta_update)
92-
93-
return pd.DataFrame(final_results)
90+
return final_results, meta_update

0 commit comments

Comments (0)