Skip to content

Commit b735443

Browse files
wip: add checkpoint
1 parent 857a16c commit b735443

13 files changed

Lines changed: 378 additions & 291 deletions

File tree

graphgen/bases/base_generator.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -42,46 +42,36 @@ async def generate(
4242

4343
@staticmethod
4444
def format_generation_results(
45-
results: list[dict], output_data_format: str
46-
) -> list[dict[str, Any]]:
45+
result: dict, output_data_format: str
46+
) -> dict[str, Any]:
47+
question = result.get("question", "")
48+
answer = result.get("answer", "")
49+
if "options" in result and result["options"]:
50+
options = result["options"]
51+
options_str = "\n".join(
52+
[f"{key}. {options[key]}" for key in sorted(options.keys())]
53+
)
54+
question += f"\nOptions:\n{options_str}"
4755

48-
flat_results = []
49-
for qa_data in results:
50-
question = qa_data.get("question", "")
51-
answer = qa_data.get("answer", "")
52-
if "options" in qa_data and qa_data["options"]:
53-
options = qa_data["options"]
54-
options_str = "\n".join(
55-
[f"{key}. {options[key]}" for key in sorted(options.keys())]
56-
)
57-
question += f"\nOptions:\n{options_str}"
56+
if output_data_format == "Alpaca":
57+
return {
58+
"instruction": question,
59+
"input": "",
60+
"output": answer,
61+
}
5862

59-
if output_data_format == "Alpaca":
60-
flat_results.append(
61-
{
62-
"instruction": question,
63-
"input": "",
64-
"output": answer,
65-
}
66-
)
67-
elif output_data_format == "Sharegpt":
68-
flat_results.append(
69-
{
70-
"conversations": [
71-
{"from": "human", "value": question},
72-
{"from": "gpt", "value": answer},
73-
]
74-
}
75-
)
76-
elif output_data_format == "ChatML":
77-
flat_results.append(
78-
{
79-
"messages": [
80-
{"role": "user", "content": question},
81-
{"role": "assistant", "content": answer},
82-
]
83-
}
84-
)
85-
else:
86-
raise ValueError(f"Unknown output data format: {output_data_format}")
87-
return flat_results
63+
if output_data_format == "Sharegpt":
64+
return {
65+
"conversations": [
66+
{"from": "human", "value": question},
67+
{"from": "gpt", "value": answer},
68+
]
69+
}
70+
if output_data_format == "ChatML":
71+
return {
72+
"messages": [
73+
{"role": "user", "content": question},
74+
{"role": "assistant", "content": answer},
75+
]
76+
}
77+
raise ValueError(f"Unknown output data format: {output_data_format}")

graphgen/bases/base_operator.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,22 @@
88

99

1010
class BaseOperator(ABC):
11-
def __init__(self, working_dir: str = "cache", op_name: str = None):
11+
def __init__(
12+
self,
13+
working_dir: str = "cache",
14+
kv_backend: str = "rocksdb",
15+
op_name: str = None,
16+
):
1217
# lazy import to avoid circular import
18+
from graphgen.common import init_storage
1319
from graphgen.utils import set_logger
1420

1521
log_dir = os.path.join(working_dir, "logs")
1622
self.op_name = op_name or self.__class__.__name__
23+
self.working_dir = working_dir
24+
self.kv_storage = init_storage(
25+
backend=kv_backend, working_dir=working_dir, namespace=self.op_name
26+
)
1727

1828
try:
1929
ctx = ray.get_runtime_context()
@@ -45,17 +55,80 @@ def __call__(
4555

4656
logger_token = CURRENT_LOGGER_VAR.set(self.logger)
4757
try:
48-
result = self.process(batch)
58+
self.kv_storage.reload()
59+
to_process, recovered = self.split(batch)
60+
# yield recovered chunks first
61+
if not recovered.empty:
62+
yield recovered
63+
64+
if to_process.empty:
65+
return
66+
67+
docs = to_process.to_dict(orient="records")
68+
result = self.process(docs)
4969
if inspect.isgenerator(result):
5070
yield from result
5171
else:
5272
yield result
5373
finally:
5474
CURRENT_LOGGER_VAR.reset(logger_token)
5575

56-
@abstractmethod
57-
def process(self, batch):
58-
raise NotImplementedError("Subclasses must implement the process method.")
59-
6076
def get_logger(self):
    """Return the logger bound to this operator instance."""
    return self.logger
78+
79+
def get_meta_forward(self):
    """Return the forward meta map stored under "_meta_forward".

    Empty dict when the KV storage has no (or a falsy) entry.
    """
    stored = self.kv_storage.get_by_id("_meta_forward")
    return stored if stored else {}
81+
82+
def get_meta_inverse(self):
    """Return the inverse meta map stored under "_meta_inverse".

    Empty dict when the KV storage has no (or a falsy) entry.
    """
    stored = self.kv_storage.get_by_id("_meta_inverse")
    return stored if stored else {}
84+
85+
def get_trace_id(self, content: dict) -> str:
    """Return a deterministic trace id for *content*, namespaced by this
    operator's name (hash is prefixed with ``"<op_name>-"``)."""
    # Imported lazily to avoid a circular import at module load time.
    from graphgen.utils import compute_dict_hash

    return compute_dict_hash(content, prefix=f"{self.op_name}-")
89+
90+
def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split *batch* into rows still to process and rows recoverable from KV storage.

    :param batch: DataFrame carrying a ``_trace_id`` column.
    :return:
        to_process: rows whose trace id has no forward-meta entry yet;
        recovered: result rows previously stored for the already-seen trace ids
        (empty DataFrame when nothing was seen before).
    """
    forward = self.get_meta_forward()
    seen_mask = batch["_trace_id"].isin(set(forward))
    pending = batch[~seen_mask]
    done = batch[seen_mask]

    if done.empty:
        return pending, pd.DataFrame()

    # Collect every stored result id produced by the already-processed rows.
    stored_ids = []
    for tid in done["_trace_id"]:
        stored_ids.extend(forward.get(tid, []))

    # Drop ids that are missing from the KV store (get_by_ids yields None for them).
    rows = [row for row in self.kv_storage.get_by_ids(stored_ids) if row is not None]
    return pending, pd.DataFrame(rows)
114+
115+
def store(self, results: list, meta_update: dict):
    """Persist processed results and refresh the forward/inverse meta maps.

    :param results: result dicts, each carrying a "_trace_id" key.
    :param meta_update: mapping of source trace id -> list of produced result ids.
    """
    self.kv_storage.upsert({item["_trace_id"]: item for item in results})

    # Forward meta: source trace id -> ids of the results it produced.
    forward = self.get_meta_forward()
    forward.update(meta_update)
    self.kv_storage.update({"_meta_forward": forward})

    # Inverse meta: produced result id -> source trace id.
    inverse = self.get_meta_inverse()
    inverse.update(
        {rid: src for src, rids in meta_update.items() for rid in rids}
    )
    self.kv_storage.update({"_meta_inverse": inverse})
    self.kv_storage.index_done_callback()
132+
@abstractmethod
def process(self, batch: list) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
    """Transform a batch of record dicts into result DataFrame(s).

    Subclasses must implement this. The return value may be a single
    DataFrame or a generator of DataFrames; the caller handles both
    (it yields from generators and yields plain results as-is).
    """
    pass

graphgen/models/extractor/schema_guided_extractor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing import Dict, List
33

4-
from graphgen.bases import BaseExtractor, BaseLLMWrapper
4+
from graphgen.bases import BaseExtractor, BaseLLMWrapper, Chunk
55
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
66
from graphgen.utils import compute_dict_hash, detect_main_language, logger
77

@@ -59,9 +59,9 @@ def build_prompt(self, text: str) -> str:
5959
)
6060
return prompt
6161

62-
async def extract(self, chunk: dict) -> dict:
62+
async def extract(self, chunk: Chunk) -> dict:
6363
_chunk_id = chunk.get("_chunk_id", "")
64-
text = chunk.get("content", "")
64+
text = chunk.content
6565

6666
prompt = self.build_prompt(text)
6767
response = await self.llm_client.generate_answer(prompt)

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
import pandas as pd
42

53
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator
@@ -13,31 +11,27 @@
1311

1412
class BuildKGService(BaseOperator):
1513
def __init__(
    self,
    working_dir: str = "cache",
    kv_backend: str = "rocksdb",
    graph_backend: str = "kuzu",
    **build_kwargs
):
    """Operator that extracts entities/relations and merges them into a graph store.

    :param working_dir: root directory for caches, logs and storage backends.
    :param kv_backend: key-value backend name forwarded to BaseOperator.
    :param graph_backend: backend name for the graph storage ("graph" namespace).
    :param build_kwargs: extra build options; only "max_loop" (default 3) is read here.
    """
    super().__init__(
        working_dir=working_dir, kv_backend=kv_backend, op_name="build_kg"
    )
    # LLM client initialized for the "synthesizer" role.
    self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
    # Graph storage the extracted nodes/edges are merged into.
    self.graph_storage: BaseGraphStorage = init_storage(
        backend=graph_backend, working_dir=working_dir, namespace="graph"
    )
    self.build_kwargs = build_kwargs
    # Presumably bounds extraction retry/refinement iterations — TODO confirm.
    self.max_loop: int = int(self.build_kwargs.get("max_loop", 3))
2529

26-
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
27-
docs = batch.to_dict(orient="records")
28-
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]
29-
30-
# consume the chunks and build kg
31-
nodes, edges = self.build_kg(docs)
32-
return pd.DataFrame(
33-
[{"node": node, "edge": []} for node in nodes]
34-
+ [{"node": [], "edge": edge} for edge in edges]
35-
)
36-
37-
def build_kg(self, chunks: List[Chunk]) -> tuple:
30+
def process(self, batch: list) -> pd.DataFrame:
3831
"""
3932
Build knowledge graph (KG) and merge into kg_instance
4033
"""
34+
chunks = [Chunk.from_dict(doc["_trace_id"], doc) for doc in batch]
4135
text_chunks = [chunk for chunk in chunks if chunk.type == "text"]
4236
mm_chunks = [
4337
chunk
@@ -75,4 +69,38 @@ def build_kg(self, chunks: List[Chunk]) -> tuple:
7569
self.graph_storage.index_done_callback()
7670
logger.info("Knowledge graph building completed.")
7771

78-
return nodes, edges
72+
meta_updates = {}
73+
results = []
74+
for node in nodes:
75+
if not node:
76+
continue
77+
trace_id = node["entity_name"]
78+
results.append(
79+
{
80+
"_trace_id": trace_id,
81+
"node": node,
82+
"edge": {},
83+
}
84+
)
85+
source_ids = node.get("source_id", "").split("<SEP>")
86+
for source_id in source_ids:
87+
meta_updates.setdefault(source_id, []).append(trace_id)
88+
for edge in edges:
89+
if not edge:
90+
continue
91+
trace_id = frozenset((edge["src_id"], edge["tgt_id"]))
92+
results.append(
93+
{
94+
"_trace_id": str(trace_id),
95+
"node": {},
96+
"edge": edge,
97+
}
98+
)
99+
source_ids = edge.get("source_id", "").split("<SEP>")
100+
for source_id in source_ids:
101+
meta_updates.setdefault(source_id, []).append(str(trace_id))
102+
self.store(
103+
results,
104+
meta_updates,
105+
)
106+
return pd.DataFrame(results)

graphgen/operators/build_kg/build_text_kg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def build_text_kg(
3030
desc="[2/4]Extracting entities and relationships from chunks",
3131
unit="chunk",
3232
)
33+
results = [res for res in results if res]
3334

3435
nodes = defaultdict(list)
3536
edges = defaultdict(list)

0 commit comments

Comments
 (0)