Skip to content

Commit 4c28c9a

Browse files
merge
2 parents edeb131 + 661211e commit 4c28c9a

11 files changed

Lines changed: 282 additions & 57 deletions

File tree

examples/filter/filter.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/filter/filter_config.yaml

examples/filter/filter_config.yaml

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: networkx # graph database backend; supported backends: kuzu, networkx
4+
kv_backend: json_kv # key-value store backend; supported backends: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/jsonl_demo.jsonl # input file path; supported formats: json, jsonl, txt, pdf. See examples/input_examples for examples
14+
15+
- id: chunk_documents
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read_files
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 1024 # chunk size for text splitting
24+
chunk_overlap: 100 # chunk overlap for text splitting
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk_documents
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: quiz
36+
op_name: quiz
37+
type: aggregate
38+
dependencies:
39+
- build_kg
40+
execution_params:
41+
replicas: 1
42+
batch_size: 128
43+
params:
44+
quiz_samples: 2 # number of quiz samples to generate
45+
concurrency_limit: 200
46+
47+
- id: judge
48+
op_name: judge
49+
type: map_batch
50+
dependencies:
51+
- quiz
52+
execution_params:
53+
replicas: 1
54+
batch_size: 128
55+
56+
- id: partition
57+
op_name: partition
58+
type: aggregate
59+
dependencies:
60+
- judge
61+
params:
62+
method: ece # ece is a custom partition method based on comprehension loss
63+
method_params:
64+
max_units_per_community: 20 # max nodes and edges per community
65+
min_units_per_community: 5 # min nodes and edges per community
66+
max_tokens_per_community: 10240 # max tokens per community
67+
unit_sampling: max_loss # unit sampling strategy; supported strategies: random, max_loss, min_loss
68+
69+
- id: generate
70+
op_name: generate
71+
type: map_batch
72+
dependencies:
73+
- partition
74+
execution_params:
75+
replicas: 1
76+
batch_size: 128
77+
save_output: true
78+
params:
79+
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
80+
data_format: ChatML # Alpaca, Sharegpt, ChatML
81+
82+
- id: evaluate
83+
op_name: evaluate
84+
type: map_batch
85+
dependencies:
86+
- generate
87+
execution_params:
88+
replicas: 1
89+
batch_size: 128
90+
save_output: true
91+
params:
92+
target: qa
93+
metrics:
94+
- length
95+
- mtld
96+
# - reward_score
97+
# - uni_score
98+
mtld_params:
99+
threshold: 0.7
100+
101+
- id: filter
102+
op_name: filter
103+
type: filter
104+
dependencies:
105+
- evaluate
106+
execution_params:
107+
replicas: 1
108+
batch_size: 128
109+
save_output: true
110+
params:
111+
method: range
112+
method_params:
113+
metric: mtld
114+
min_val: 300
115+
max_val: 400
116+

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .base_extractor import BaseExtractor
2+
from .base_filter import BaseValueFilter
23
from .base_generator import BaseGenerator
34
from .base_kg_builder import BaseKGBuilder
45
from .base_llm_wrapper import BaseLLMWrapper

graphgen/bases/base_filter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Union
3+
4+
import numpy as np
5+
6+
7+
class BaseFilter(ABC):
    """Abstract interface for predicates that accept or reject a datum."""

    @abstractmethod
    def filter(self, data: Any) -> bool:
        """Return True when *data* passes the filter, False otherwise."""
        raise NotImplementedError


class BaseValueFilter(BaseFilter, ABC):
    """Abstract filter specialised to scalar numeric inputs."""

    @abstractmethod
    def filter(self, data: Union[int, float, np.number]) -> bool:
        """Return True when the numeric *data* passes the filter, False otherwise."""
        raise NotImplementedError

    @property
    @abstractmethod
    def filter_type(self) -> str:
        """Name the kind of filter (e.g., "greater_than", "less_than", etc.)."""
        raise NotImplementedError

graphgen/engine.py

Lines changed: 39 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import os
44
from collections import defaultdict, deque
5-
from functools import wraps
65
from typing import Any, Callable, Dict, List, Set
76

87
import ray
@@ -103,7 +102,6 @@ def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
103102
kv_namespaces = set()
104103
graph_namespaces = set()
105104

106-
# TODO: Temporarily hard-coded; node storage will be centrally managed later.
107105
for node in self.config.nodes:
108106
op_name = node.op_name
109107
if self._function_needs_param(op_name, "kv_backend"):
@@ -232,62 +230,38 @@ def _filter_kwargs(
232230

233231
input_ds = self._get_input_dataset(node, initial_ds)
234232

235-
if inspect.isclass(op_handler):
236-
execution_params = node.execution_params or {}
237-
replicas = execution_params.get("replicas", 1)
238-
batch_size = (
239-
int(execution_params.get("batch_size"))
240-
if "batch_size" in execution_params
241-
else "default"
233+
# if inspect.isclass(op_handler):
234+
execution_params = node.execution_params or {}
235+
replicas = execution_params.get("replicas", 1)
236+
batch_size = (
237+
int(execution_params.get("batch_size"))
238+
if "batch_size" in execution_params
239+
else "default"
240+
)
241+
compute_resources = execution_params.get("compute_resources", {})
242+
243+
if node.type == "aggregate":
244+
self.datasets[node.id] = input_ds.repartition(1).map_batches(
245+
op_handler,
246+
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
247+
batch_size=None, # aggregate processes the whole dataset at once
248+
num_gpus=compute_resources.get("num_gpus", 0)
249+
if compute_resources
250+
else 0,
251+
fn_constructor_kwargs=node_params,
252+
batch_format="pandas",
242253
)
243-
compute_resources = execution_params.get("compute_resources", {})
244-
245-
if node.type == "aggregate":
246-
self.datasets[node.id] = input_ds.repartition(1).map_batches(
247-
op_handler,
248-
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
249-
batch_size=None, # aggregate processes the whole dataset at once
250-
num_gpus=compute_resources.get("num_gpus", 0)
251-
if compute_resources
252-
else 0,
253-
fn_constructor_kwargs=node_params,
254-
batch_format="pandas",
255-
)
256-
else:
257-
# others like map, filter, flatmap, map_batch let actors process data inside batches
258-
self.datasets[node.id] = input_ds.map_batches(
259-
op_handler,
260-
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
261-
batch_size=batch_size,
262-
num_gpus=compute_resources.get("num_gpus", 0)
263-
if compute_resources
264-
else 0,
265-
fn_constructor_kwargs=node_params,
266-
batch_format="pandas",
267-
)
268-
269254
else:
270-
271-
@wraps(op_handler)
272-
def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
273-
return op_handler(row_or_batch, **node_params)
274-
275-
if node.type == "map":
276-
self.datasets[node.id] = input_ds.map(func_wrapper)
277-
elif node.type == "filter":
278-
self.datasets[node.id] = input_ds.filter(func_wrapper)
279-
elif node.type == "flatmap":
280-
self.datasets[node.id] = input_ds.flat_map(func_wrapper)
281-
elif node.type == "aggregate":
282-
self.datasets[node.id] = input_ds.repartition(1).map_batches(
283-
func_wrapper, batch_format="default"
284-
)
285-
elif node.type == "map_batch":
286-
self.datasets[node.id] = input_ds.map_batches(func_wrapper)
287-
else:
288-
raise ValueError(
289-
f"Unsupported node type {node.type} for node {node.id}"
290-
)
255+
self.datasets[node.id] = input_ds.map_batches(
256+
op_handler,
257+
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
258+
batch_size=batch_size,
259+
num_gpus=compute_resources.get("num_gpus", 0)
260+
if compute_resources
261+
else 0,
262+
fn_constructor_kwargs=node_params,
263+
batch_format="pandas",
264+
)
291265

292266
def execute(
293267
self, initial_ds: ray.data.Dataset, output_dir: str
@@ -315,6 +289,14 @@ def execute(
315289
logger.info("Node %s output saved to %s", node.id, node_output_path)
316290

317291
# ray will lazy read the dataset
318-
self.datasets[node.id] = ray.data.read_json(node_output_path)
292+
if os.path.exists(node_output_path) and os.listdir(node_output_path):
293+
self.datasets[node.id] = ray.data.read_json(node_output_path)
294+
else:
295+
self.datasets[node.id] = ray.data.from_items([])
296+
logger.warning(
297+
"Node %s output path %s is empty. Created an empty dataset.",
298+
node.id,
299+
node_output_path,
300+
)
319301

320302
return self.datasets

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
StructureEvaluator,
77
UniEvaluator,
88
)
9+
from .filter import RangeFilter
910
from .generator import (
1011
AggregatedGenerator,
1112
AtomicGenerator,

graphgen/models/filter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .range_filter import RangeFilter
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
5+
from graphgen.bases import BaseValueFilter
6+
7+
8+
class RangeFilter(BaseValueFilter):
    """Keep numeric values that fall inside [min_val, max_val].

    Each bound can independently be inclusive or exclusive via the
    ``left_inclusive`` / ``right_inclusive`` flags (both default to True,
    i.e. a closed interval).
    """

    def __init__(
        self,
        min_val: float,
        max_val: float,
        left_inclusive: bool = True,
        right_inclusive: bool = True,
    ):
        # Interval bounds plus whether each endpoint belongs to the interval.
        self.min_val = min_val
        self.max_val = max_val
        self.left_inclusive = left_inclusive
        self.right_inclusive = right_inclusive

    def filter(self, data: Union[int, float, np.number]) -> bool:
        """Return True when *data* lies within the configured range."""
        value = float(data)
        # Check each bound separately, honouring its inclusivity flag.
        above_min = (
            value >= self.min_val if self.left_inclusive else value > self.min_val
        )
        below_max = (
            value <= self.max_val if self.right_inclusive else value < self.max_val
        )
        return above_min and below_max

    @property
    def filter_type(self) -> str:
        return "range"

    def __repr__(self) -> str:
        return f"RangeFilter({self.min_val}, {self.max_val})"

graphgen/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .chunk import ChunkService
33
from .evaluate import EvaluateService
44
from .extract import ExtractService
5+
from .filter import FilterService
56
from .generate import GenerateService
67
from .judge import JudgeService
78
from .partition import PartitionService
@@ -22,4 +23,5 @@
2223
"generate": GenerateService,
2324
"evaluate": EvaluateService,
2425
"rephrase": RephraseService,
26+
"filter": FilterService,
2527
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .filter_service import FilterService

0 commit comments

Comments
 (0)