Skip to content

Commit 661211e

Browse files
feat: add FilterService (#169)

* feat: add FilterService
* Update examples/filter/filter.sh

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 4c73f42 commit 661211e

11 files changed

Lines changed: 282 additions & 58 deletions

File tree

examples/filter/filter.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/filter/filter_config.yaml

examples/filter/filter_config.yaml

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Pipeline configuration for the filter example:
# read -> chunk -> build_kg -> quiz -> judge -> partition -> generate -> evaluate -> filter
global_params:
  working_dir: cache
  graph_backend: networkx # graph database backend, support: kuzu, networkx
  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_files
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024 # chunk size for text splitting
      chunk_overlap: 100 # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: quiz
    op_name: quiz
    type: aggregate
    dependencies:
      - build_kg
    execution_params:
      replicas: 1
      batch_size: 128
    params:
      quiz_samples: 2 # number of quiz samples to generate
      concurrency_limit: 200

  - id: judge
    op_name: judge
    type: map_batch
    dependencies:
      - quiz
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - judge
    params:
      method: ece # ece is a custom partition method based on comprehension loss
      method_params:
        max_units_per_community: 20 # max nodes and edges per community
        min_units_per_community: 5 # min nodes and edges per community
        max_tokens_per_community: 10240 # max tokens per community
        unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true
    params:
      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
      data_format: ChatML # Alpaca, Sharegpt, ChatML

  - id: evaluate
    op_name: evaluate
    type: map_batch
    dependencies:
      - generate
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true
    params:
      target: qa
      metrics:
        - length
        - mtld
        # - reward_score
        # - uni_score
      mtld_params:
        threshold: 0.7

  # Keep only QA pairs whose metric value falls inside the configured range.
  - id: filter
    op_name: filter
    type: filter
    dependencies:
      - evaluate
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true
    params:
      method: range
      method_params:
        metric: mtld
        min_val: 300
        max_val: 400
116+

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
22
from .base_extractor import BaseExtractor
3+
from .base_filter import BaseValueFilter
34
from .base_generator import BaseGenerator
45
from .base_kg_builder import BaseKGBuilder
56
from .base_llm_wrapper import BaseLLMWrapper

graphgen/bases/base_filter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Union
3+
4+
import numpy as np
5+
6+
7+
class BaseFilter(ABC):
    """Abstract interface for filters that decide whether a record is kept."""

    @abstractmethod
    def filter(self, data: Any) -> bool:
        """Return True when *data* passes the filter, False when it should be dropped."""
        raise NotImplementedError


class BaseValueFilter(BaseFilter, ABC):
    """Filter specialised to scalar numeric values (int, float, or numpy number)."""

    @abstractmethod
    def filter(self, data: Union[int, float, np.number]) -> bool:
        """Return True when the numeric value passes the filter, False otherwise."""
        raise NotImplementedError

    @property
    @abstractmethod
    def filter_type(self) -> str:
        """Identifier of the filter kind (e.g., "greater_than", "less_than")."""
        raise NotImplementedError

graphgen/engine.py

Lines changed: 39 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import os
44
from collections import defaultdict, deque
5-
from functools import wraps
65
from typing import Any, Callable, Dict, List, Set
76

87
import ray
@@ -103,7 +102,6 @@ def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
103102
kv_namespaces = set()
104103
graph_namespaces = set()
105104

106-
# TODO: Temporarily hard-coded; node storage will be centrally managed later.
107105
for node in self.config.nodes:
108106
op_name = node.op_name
109107
if self._function_needs_param(op_name, "kv_backend"):
@@ -232,62 +230,38 @@ def _filter_kwargs(
232230

233231
input_ds = self._get_input_dataset(node, initial_ds)
234232

235-
if inspect.isclass(op_handler):
236-
execution_params = node.execution_params or {}
237-
replicas = execution_params.get("replicas", 1)
238-
batch_size = (
239-
int(execution_params.get("batch_size"))
240-
if "batch_size" in execution_params
241-
else "default"
233+
# if inspect.isclass(op_handler):
234+
execution_params = node.execution_params or {}
235+
replicas = execution_params.get("replicas", 1)
236+
batch_size = (
237+
int(execution_params.get("batch_size"))
238+
if "batch_size" in execution_params
239+
else "default"
240+
)
241+
compute_resources = execution_params.get("compute_resources", {})
242+
243+
if node.type == "aggregate":
244+
self.datasets[node.id] = input_ds.repartition(1).map_batches(
245+
op_handler,
246+
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
247+
batch_size=None, # aggregate processes the whole dataset at once
248+
num_gpus=compute_resources.get("num_gpus", 0)
249+
if compute_resources
250+
else 0,
251+
fn_constructor_kwargs=node_params,
252+
batch_format="pandas",
242253
)
243-
compute_resources = execution_params.get("compute_resources", {})
244-
245-
if node.type == "aggregate":
246-
self.datasets[node.id] = input_ds.repartition(1).map_batches(
247-
op_handler,
248-
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
249-
batch_size=None, # aggregate processes the whole dataset at once
250-
num_gpus=compute_resources.get("num_gpus", 0)
251-
if compute_resources
252-
else 0,
253-
fn_constructor_kwargs=node_params,
254-
batch_format="pandas",
255-
)
256-
else:
257-
# others like map, filter, flatmap, map_batch let actors process data inside batches
258-
self.datasets[node.id] = input_ds.map_batches(
259-
op_handler,
260-
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
261-
batch_size=batch_size,
262-
num_gpus=compute_resources.get("num_gpus", 0)
263-
if compute_resources
264-
else 0,
265-
fn_constructor_kwargs=node_params,
266-
batch_format="pandas",
267-
)
268-
269254
else:
270-
271-
@wraps(op_handler)
272-
def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
273-
return op_handler(row_or_batch, **node_params)
274-
275-
if node.type == "map":
276-
self.datasets[node.id] = input_ds.map(func_wrapper)
277-
elif node.type == "filter":
278-
self.datasets[node.id] = input_ds.filter(func_wrapper)
279-
elif node.type == "flatmap":
280-
self.datasets[node.id] = input_ds.flat_map(func_wrapper)
281-
elif node.type == "aggregate":
282-
self.datasets[node.id] = input_ds.repartition(1).map_batches(
283-
func_wrapper, batch_format="default"
284-
)
285-
elif node.type == "map_batch":
286-
self.datasets[node.id] = input_ds.map_batches(func_wrapper)
287-
else:
288-
raise ValueError(
289-
f"Unsupported node type {node.type} for node {node.id}"
290-
)
255+
self.datasets[node.id] = input_ds.map_batches(
256+
op_handler,
257+
compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
258+
batch_size=batch_size,
259+
num_gpus=compute_resources.get("num_gpus", 0)
260+
if compute_resources
261+
else 0,
262+
fn_constructor_kwargs=node_params,
263+
batch_format="pandas",
264+
)
291265

292266
def execute(
293267
self, initial_ds: ray.data.Dataset, output_dir: str
@@ -315,6 +289,14 @@ def execute(
315289
logger.info("Node %s output saved to %s", node.id, node_output_path)
316290

317291
# ray will lazy read the dataset
318-
self.datasets[node.id] = ray.data.read_json(node_output_path)
292+
if os.path.exists(node_output_path) and os.listdir(node_output_path):
293+
self.datasets[node.id] = ray.data.read_json(node_output_path)
294+
else:
295+
self.datasets[node.id] = ray.data.from_items([])
296+
logger.warning(
297+
"Node %s output path %s is empty. Created an empty dataset.",
298+
node.id,
299+
node_output_path,
300+
)
319301

320302
return self.datasets

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
StructureEvaluator,
77
UniEvaluator,
88
)
9+
from .filter import RangeFilter
910
from .generator import (
1011
AggregatedGenerator,
1112
AtomicGenerator,

graphgen/models/filter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .range_filter import RangeFilter
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
5+
from graphgen.bases import BaseValueFilter
6+
7+
8+
class RangeFilter(BaseValueFilter):
    """
    Keep numeric values that fall inside [min_val, max_val]; each boundary
    can independently be inclusive or exclusive.
    """

    def __init__(
        self,
        min_val: float,
        max_val: float,
        left_inclusive: bool = True,
        right_inclusive: bool = True,
    ):
        self.min_val = min_val
        self.max_val = max_val
        self.left_inclusive = left_inclusive
        self.right_inclusive = right_inclusive

    def filter(self, data: Union[int, float, np.number]) -> bool:
        """Return True when *data* lies within the configured range."""
        value = float(data)
        # Evaluate each boundary on its own so the inclusivity flags stay independent.
        if self.left_inclusive:
            above_lower = value >= self.min_val
        else:
            above_lower = value > self.min_val
        if self.right_inclusive:
            below_upper = value <= self.max_val
        else:
            below_upper = value < self.max_val
        return above_lower and below_upper

    @property
    def filter_type(self) -> str:
        return "range"

    def __repr__(self) -> str:
        return f"RangeFilter({self.min_val}, {self.max_val})"

graphgen/operators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from .chunk import ChunkService
33
from .evaluate import EvaluateService
44
from .extract import ExtractService
5+
from .filter import FilterService
56
from .generate import GenerateService
67
from .judge import JudgeService
78
from .partition import PartitionService
89
from .quiz import QuizService
910
from .read import read
1011
from .search import SearchService
1112

12-
1313
operators = {
1414
"read": read,
1515
"chunk": ChunkService,
@@ -21,4 +21,5 @@
2121
"partition": PartitionService,
2222
"generate": GenerateService,
2323
"evaluate": EvaluateService,
24+
"filter": FilterService,
2425
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .filter_service import FilterService

0 commit comments

Comments
 (0)