Skip to content

Commit 7c9240a

Browse files
refactor: refactor evaluators
1 parent af530de commit 7c9240a

32 files changed

Lines changed: 587 additions & 1477 deletions

examples/evaluate/evaluate_kg/kg_evaluation_config.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ nodes:
1010
dependencies: []
1111
params:
1212
input_path:
13-
- examples/input_examples/extract_demo.txt
13+
- examples/input_examples/jsonl_demo.jsonl
1414

1515
- id: chunk
1616
op_name: chunk
@@ -39,7 +39,6 @@ nodes:
3939
dependencies:
4040
- build_kg
4141
params:
42+
target: kg
4243
metrics:
43-
- kg_structure
44-
- kg_accuracy
45-
- kg_consistency
44+
- structure

examples/evaluate/evaluate_qa/qa_evaluation_config.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
global_params:
22
working_dir: cache
3-
graph_backend: kuzu # graph database backend, support: kuzu, networkx
4-
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
3+
graph_backend: networkx # graph database backend, support: kuzu, networkx
4+
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
55

66
nodes:
77
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
@@ -89,10 +89,11 @@ nodes:
8989
batch_size: 128
9090
save_output: true
9191
params:
92+
target: qa
9293
metrics:
93-
- qa_length
94-
- qa_mtld
95-
# - qa_reward_score
96-
# - qa_uni_score
94+
- length
95+
- mtld
96+
# - reward_score
97+
# - uni_score
9798
mtld_params:
9899
threshold: 0.7
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/evaluate/evaluate_triple/triple_evaluation_config.yaml
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: networkx # graph database backend, support: kuzu, networkx
4+
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/jsonl_demo.jsonl
14+
15+
- id: chunk
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 20480 # larger chunk size for better context
24+
chunk_overlap: 2000
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: evaluate
36+
op_name: evaluate
37+
type: aggregate
38+
save_output: true
39+
dependencies:
40+
- build_kg
41+
params:
42+
target: triple
43+
src_namespace: chunk
44+
tgt_namespace: build_kg
45+
metrics:
46+
- accuracy

graphgen/bases/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
12
from .base_extractor import BaseExtractor
23
from .base_generator import BaseGenerator
34
from .base_kg_builder import BaseKGBuilder
@@ -9,5 +10,4 @@
910
from .base_splitter import BaseSplitter
1011
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
1112
from .base_tokenizer import BaseTokenizer
12-
from .base_evaluator import BaseEvaluator
1313
from .datatypes import Chunk, Config, Node, QAPair, Token

graphgen/bases/base_evaluator.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
3+
4+
from .base_storage import BaseGraphStorage
25
from .datatypes import QAPair
36

47

5-
class BaseEvaluator(ABC):
8+
class BaseQAEvaluator(ABC):
69
@abstractmethod
7-
def evaluate(self, pair: QAPair) -> float:
10+
async def evaluate(self, pair: QAPair) -> dict[str, float]:
811
"""
912
Evaluate the text and return a score.
1013
"""
14+
15+
16+
class BaseKGEvaluator(ABC):
17+
@abstractmethod
18+
def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]:
19+
"""
20+
Evaluate the whole graph and return a dict of scores.
21+
"""
22+
23+
24+
class BaseTripleEvaluator(ABC):
25+
@abstractmethod
26+
async def evaluate(self, unit: dict) -> dict[str, float]:
27+
"""
28+
Evaluate a node/edge and return a score.
29+
"""

graphgen/bases/base_operator.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
import inspect
22
import os
33
from abc import ABC, abstractmethod
4-
from typing import Iterable, Union, Tuple
4+
from typing import Iterable, Tuple, Union
55

6+
import numpy as np
67
import pandas as pd
78
import ray
89

910

11+
def convert_to_serializable(obj):
12+
if isinstance(obj, np.ndarray):
13+
return obj.tolist()
14+
if isinstance(obj, np.generic):
15+
return obj.item()
16+
if isinstance(obj, dict):
17+
return {k: convert_to_serializable(v) for k, v in obj.items()}
18+
if isinstance(obj, list):
19+
return [convert_to_serializable(v) for v in obj]
20+
return obj
21+
22+
1023
class BaseOperator(ABC):
1124
def __init__(
1225
self,
@@ -21,6 +34,7 @@ def __init__(
2134
log_dir = os.path.join(working_dir, "logs")
2235
self.op_name = op_name or self.__class__.__name__
2336
self.working_dir = working_dir
37+
self.kv_backend = kv_backend
2438
self.kv_storage = init_storage(
2539
backend=kv_backend, working_dir=working_dir, namespace=self.op_name
2640
)
@@ -118,6 +132,9 @@ def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
118132
return to_process, pd.DataFrame(recovered_chunks)
119133

120134
def store(self, results: list, meta_update: dict):
135+
results = convert_to_serializable(results)
136+
meta_update = convert_to_serializable(meta_update)
137+
121138
batch = {res["_trace_id"]: res for res in results}
122139
self.kv_storage.upsert(batch)
123140

graphgen/bases/datatypes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ class QAPair:
3131
question: str
3232
answer: str
3333

34+
@staticmethod
35+
def from_dict(data: dict) -> "QAPair":
36+
return QAPair(
37+
question=data.get("question", ""),
38+
answer=data.get("answer", ""),
39+
)
40+
3441

3542
@dataclass
3643
class Token:

graphgen/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .evaluator import (
22
AccuracyEvaluator,
3-
ConsistencyEvaluator,
43
LengthEvaluator,
54
MTLDEvaluator,
65
RewardEvaluator,
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
1+
from .kg import StructureEvaluator
22
from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
3+
from .triple import AccuracyEvaluator

0 commit comments

Comments
 (0)