Commit 9da4e9b

refactor: use kv_storage for cache of ParallelFileScanner
1 parent 97a03f2

9 files changed: 134 additions & 129 deletions

graphgen/bases/base_storage.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -39,6 +39,12 @@ def filter_keys(self, data: list[str]) -> set[str]:
     def upsert(self, data: dict[str, T]):
         raise NotImplementedError
 
+    def update(self, data: dict[str, T]):
+        raise NotImplementedError
+
+    def delete(self, ids: list[str]):
+        raise NotImplementedError
+
     def drop(self):
         raise NotImplementedError
 
```

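These two hooks round out the base KV contract, so callers that previously needed the deleted RocksDBCache can now target any backend. For instance, the scanner's removed invalidate() helper (see parallel_file_scanner.py below) can be re-expressed against the contract alone. A minimal sketch, assuming the hashed key scheme introduced below; invalidate_scan itself is hypothetical:

```python
from graphgen.bases.base_storage import BaseKVStorage
from graphgen.utils import compute_content_hash


def invalidate_scan(kv: BaseKVStorage, path_str: str, recursive: bool) -> None:
    # Hypothetical helper: rebuild the hashed cache key that
    # ParallelFileScanner (below) stores under, then drop that one entry.
    # Keys are content hashes now, so prefix scans are no longer an option.
    cache_key = compute_content_hash(
        f"scan::{path_str}::recursive::{recursive}", prefix="read-"
    )
    kv.delete([cache_key])
```
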
graphgen/common/init_storage.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -42,6 +42,12 @@ def filter_keys(self, data: list[str]) -> set[str]:
     def upsert(self, data: dict) -> dict:
         return self.kv.upsert(data)
 
+    def update(self, data: dict):
+        return self.kv.update(data)
+
+    def delete(self, ids: list[str]):
+        return self.kv.delete(ids)
+
     def drop(self):
         return self.kv.drop()
 
@@ -168,6 +174,12 @@ def filter_keys(self, data: list[str]) -> set[str]:
     def upsert(self, data: Dict[str, Any]):
        return ray.get(self.actor.upsert.remote(data))
 
+    def update(self, data: Dict[str, Any]):
+        return ray.get(self.actor.update.remote(data))
+
+    def delete(self, ids: list[str]):
+        return ray.get(self.actor.delete.remote(ids))
+
     def drop(self):
         return ray.get(self.actor.drop.remote())
 
```

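The actor-backed wrapper forwards the new methods exactly like the existing ones: a synchronous ray.get around a .remote() call, so each update/delete blocks until the actor has applied it. A self-contained sketch of that pattern (KVActor and KVProxy are illustrative stand-ins, not the module's real classes):

```python
import ray


@ray.remote
class KVActor:
    # Toy actor-side store standing in for the real KV backend.
    def __init__(self):
        self._data = {}

    def update(self, data: dict):
        self._data.update(data)

    def delete(self, ids: list):
        for _id in ids:
            self._data.pop(_id, None)


class KVProxy:
    # Synchronous facade: callers never see futures.
    def __init__(self, actor):
        self.actor = actor

    def update(self, data: dict):
        return ray.get(self.actor.update.remote(data))

    def delete(self, ids: list):
        return ray.get(self.actor.delete.remote(ids))


if __name__ == "__main__":
    ray.init()
    proxy = KVProxy(KVActor.remote())
    proxy.update({"doc-1": {"text": "hello"}})
    proxy.delete(["doc-1"])
```
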
graphgen/models/__init__.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -44,11 +44,5 @@
 from .searcher.web.bing_search import BingSearch
 from .searcher.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage import (
-    JsonKVStorage,
-    KuzuStorage,
-    NetworkXStorage,
-    RocksDBCache,
-    RocksDBKVStorage,
-)
+from .storage import JsonKVStorage, KuzuStorage, NetworkXStorage, RocksDBKVStorage
 from .tokenizer import Tokenizer
```

graphgen/models/storage/__init__.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -2,5 +2,3 @@
 from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
 from graphgen.models.storage.kv.json_storage import JsonKVStorage
 from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
-
-from .rocksdb_cache import RocksDBCache
```

graphgen/models/storage/kv/json_storage.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 
-from graphgen.bases.base_storage import BaseKVStorage
+from graphgen.bases.base_storage import BaseKVStorage, T
 from graphgen.utils import load_json, write_json
 
 
@@ -51,6 +51,15 @@ def upsert(self, data: dict):
         self._data.update(left_data)
         return left_data
 
+    def update(self, data: dict[str, T]):
+        for k, v in data.items():
+            self._data[k] = v
+
+    def delete(self, ids: list[str]):
+        for _id in ids:
+            if _id in self._data:
+                del self._data[_id]
+
     def drop(self):
         if self._data:
             self._data.clear()
```

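Note the asymmetry this leaves in JsonKVStorage: upsert merges only left_data (which, judging from the surrounding code, holds the keys not yet stored), while the new update overwrites unconditionally and delete skips missing ids. A runnable stand-in capturing those semantics (_TinyKV is illustrative; constructing the real JsonKVStorage is outside this diff):

```python
class _TinyKV:
    # Illustrative stand-in mirroring the JsonKVStorage semantics above.
    def __init__(self):
        self._data = {}

    def upsert(self, data: dict) -> dict:
        # Insert-if-absent: keys already present stay untouched.
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    def update(self, data: dict):
        # Unconditional overwrite.
        for k, v in data.items():
            self._data[k] = v

    def delete(self, ids: list):
        for _id in ids:
            if _id in self._data:
                del self._data[_id]


kv = _TinyKV()
kv.upsert({"doc-1": "v1"})
kv.upsert({"doc-1": "v2"})       # ignored: doc-1 already stored
assert kv._data["doc-1"] == "v1"
kv.update({"doc-1": "v2"})       # overwritten in place
assert kv._data["doc-1"] == "v2"
kv.delete(["doc-1", "missing"])  # missing ids are skipped silently
```
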
graphgen/models/storage/kv/rocksdb_storage.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -68,6 +68,15 @@ def upsert(self, data: Dict[str, Any]):
 
         return left_data
 
+    def update(self, data: Dict[str, Any]):
+        for k, v in data.items():
+            self._db[k] = v
+
+    def delete(self, ids: List[str]):
+        for _id in ids:
+            if _id in self._db:
+                del self._db[_id]
+
     def drop(self):
         self._db.close()
         Rdict.destroy(self._db_path)
```

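Both methods lean on Rdict's dict-style dunders rather than a batch write, so each key is an individual RocksDB put or delete. A standalone sketch against the rocksdict package (which provides the Rdict used above); the path is illustrative:

```python
from rocksdict import Rdict

db_path = "/tmp/rocksdict_demo"   # illustrative path
db = Rdict(db_path)

db["doc-1"] = {"text": "hello"}   # put; Python values are serialized automatically
print("doc-1" in db)              # membership check -> True
del db["doc-1"]                   # single-key delete

db.close()
Rdict.destroy(db_path)            # remove the on-disk database files
```
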
graphgen/models/storage/rocksdb_cache.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

graphgen/operators/read/parallel_file_scanner.py

Lines changed: 33 additions & 23 deletions
```diff
@@ -2,17 +2,22 @@
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
-from typing import Any, Dict, List, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
 
-from graphgen.models import RocksDBCache
+from graphgen.bases import BaseKVStorage
+from graphgen.utils import compute_content_hash, logger
 
 
 class ParallelFileScanner:
     def __init__(
-        self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4
+        self,
+        read_cache: BaseKVStorage,
+        allowed_suffix: Optional[List[str]] = None,
+        rescan: bool = False,
+        max_workers: int = 4,
     ):
-        self.cache = RocksDBCache(os.path.join(cache_dir, "input_paths.db"))
-        self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None
+        self.cache = read_cache
+        self.allowed_suffix = set(allowed_suffix) if allowed_suffix else set()
         self.rescan = rescan
         self.max_workers = max_workers
 
@@ -55,8 +60,10 @@ def _scan_files(
             return self._empty_result(path_str)
 
         # cache check
-        cache_key = f"scan::{path_str}::recursive::{recursive}"
-        cached = self.cache.get(cache_key)
+        cache_key = compute_content_hash(
+            f"scan::{path_str}::recursive::{recursive}", prefix="read-"
+        )
+        cached = self.cache.get_by_id(cache_key)
         if cached and not self.rescan:
             return cached["data"]
 
@@ -66,7 +73,9 @@ def _scan_files(
         try:
             path_stat = path.stat()
             if path.is_file():
-                return self._scan_single_file(path, path_str, path_stat)
+                result = self._scan_single_file(path, path_str, path_stat)
+                self._cache_result(cache_key, result, path)
+                return result
             if path.is_dir():
                 with os.scandir(path_str) as entries:
                     for entry in entries:
@@ -113,6 +122,12 @@ def _scan_files(
             stats["file_count"] += sub_data["stats"].get("file_count", 0)
 
         result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats}
+        logger.debug(
+            "Scanned %s: %d files, %d dirs",
+            path_str,
+            stats["file_count"],
+            stats["dir_count"],
+        )
         self._cache_result(cache_key, result, path)
         return result
 
@@ -174,31 +189,26 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, Any]:
 
     def _cache_result(self, key: str, result: Dict, path: Path):
         """Cache the scan result"""
-        self.cache.set(
-            key,
+        self.cache.upsert(
             {
-                "data": result,
-                "dir_mtime": path.stat().st_mtime,
-                "cached_at": time.time(),
-            },
+                key: {
+                    "data": result,
+                    "dir_mtime": path.stat().st_mtime,
+                    "cached_at": time.time(),
+                },
+            }
         )
 
     def _is_allowed_file(self, path: Path) -> bool:
         """Check if the file has an allowed suffix"""
-        if self.allowed_suffix is None:
+        if not self.allowed_suffix or len(self.allowed_suffix) == 0:
             return True
         suffix = path.suffix.lower().lstrip(".")
         return suffix in self.allowed_suffix
 
-    def invalidate(self, path: str):
-        """Invalidate cache for a specific path"""
-        path = Path(path).resolve()
-        keys = [k for k in self.cache if k.startswith(f"scan::{path}")]
-        for k in keys:
-            self.cache.delete(k)
-
     def close(self):
-        self.cache.close()
+        self.cache.index_done_callback()
+        del self.cache
 
     def __enter__(self):
         return self
```

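The scanner is now backend-agnostic: anything satisfying BaseKVStorage can be injected, and exiting the with-block (as read.py does below) closes the scanner, which now flushes via index_done_callback() instead of closing a RocksDB handle. A usage sketch mirroring that wiring; the backend name, paths, and the scanner's import path are assumptions:

```python
from graphgen.common import init_storage
from graphgen.operators.read.parallel_file_scanner import ParallelFileScanner

# Illustrative wiring; read.py below does the same with kv_backend="rocksdb".
read_cache = init_storage(backend="rocksdb", working_dir="cache", namespace="read")

with ParallelFileScanner(
    read_cache=read_cache,
    allowed_suffix=["pdf", "txt"],
    rescan=False,
    max_workers=4,
) as scanner:
    scan_results = scanner.scan("data/", recursive=True)

for result in scan_results.values():
    for file_info in result.get("files", []):
        print(file_info["path"])
```
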
graphgen/operators/read/read.py

Lines changed: 63 additions & 53 deletions
```diff
@@ -3,6 +3,7 @@
 
 import ray
 
+from graphgen.common import init_storage
 from graphgen.models import (
     CSVReader,
     JSONReader,
@@ -51,6 +52,7 @@ def read(
     input_path: Union[str, List[str]],
     allowed_suffix: Optional[List[str]] = None,
     working_dir: Optional[str] = "cache",
+    kv_backend: str = "rocksdb",
     parallelism: int = 4,
     recursive: bool = True,
     read_nums: Optional[int] = None,
@@ -62,71 +64,79 @@ def read(
     :param input_path: File or directory path(s) to read from
     :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt'])
     :param working_dir: Directory to cache intermediate files (PDF processing)
+    :param kv_backend: Backend for key-value storage
     :param parallelism: Number of parallel workers
     :param recursive: Whether to scan directories recursively
     :param read_nums: Limit the number of documents to read
     :param reader_kwargs: Additional kwargs passed to readers
     :return: Ray Dataset containing all documents
     """
+
+    read_cache = init_storage(
+        backend=kv_backend, working_dir=working_dir, namespace="read"
+    )
     try:
         # 1. Scan all paths to discover files
         logger.info("[READ] Scanning paths: %s", input_path)
-        scanner = ParallelFileScanner(
-            cache_dir=working_dir,
+        with ParallelFileScanner(
+            read_cache=read_cache,
             allowed_suffix=allowed_suffix,
             rescan=False,
             max_workers=parallelism if parallelism > 0 else 1,
-        )
-
-        all_files = []
-        scan_results = scanner.scan(input_path, recursive=recursive)
-
-        for result in scan_results.values():
-            all_files.extend(result.get("files", []))
-
-        logger.info("[READ] Found %d files to process", len(all_files))
-
-        if not all_files:
-            raise ValueError("No files found to read.")
-
-        # 2. Group files by suffix to use appropriate reader
-        files_by_suffix = {}
-        for file_info in all_files:
-            suffix = Path(file_info["path"]).suffix.lower().lstrip(".")
-            if allowed_suffix and suffix not in [
-                s.lower().lstrip(".") for s in allowed_suffix
-            ]:
-                continue
-            files_by_suffix.setdefault(suffix, []).append(file_info["path"])
-
-        # 3. Create read tasks
-        read_tasks = []
-        for suffix, file_paths in files_by_suffix.items():
-            reader = _build_reader(suffix, working_dir, **reader_kwargs)
-            ds = reader.read(file_paths)
-            read_tasks.append(ds)
-
-        # 4. Combine all datasets
-        if not read_tasks:
-            raise ValueError("No datasets created from the provided files.")
-
-        if len(read_tasks) == 1:
-            combined_ds = read_tasks[0]
-        else:
-            combined_ds = read_tasks[0].union(*read_tasks[1:])
-
-        combined_ds = combined_ds.map(
-            lambda record: {
-                **record,
-                "_doc_id": compute_mm_hash(record, prefix="doc-"),
-            }
-        )
-
-        if read_nums is not None:
-            combined_ds = combined_ds.limit(read_nums)
-
-        logger.info("[READ] Successfully read files from %s", input_path)
-        return combined_ds
+        ) as scanner:
+            all_files = []
+            scan_results = scanner.scan(input_path, recursive=recursive)
+
+            for result in scan_results.values():
+                all_files.extend(result.get("files", []))
+
+            logger.info("[READ] Found %d files to process", len(all_files))
+
+            if not all_files:
+                raise ValueError("No files found to read.")
+
+            # 2. Group files by suffix to use appropriate reader
+            files_by_suffix = {}
+            for file_info in all_files:
+                suffix = Path(file_info["path"]).suffix.lower().lstrip(".")
+                if allowed_suffix and suffix not in [
+                    s.lower().lstrip(".") for s in allowed_suffix
+                ]:
+                    continue
+                files_by_suffix.setdefault(suffix, []).append(file_info["path"])
+
+            # 3. Create read tasks
+            read_tasks = []
+            for suffix, file_paths in files_by_suffix.items():
+                reader = _build_reader(suffix, working_dir, **reader_kwargs)
+                ds = reader.read(file_paths)
+                read_tasks.append(ds)
+
+            # 4. Combine all datasets
+            if not read_tasks:
+                raise ValueError("No datasets created from the provided files.")
+
+            if len(read_tasks) == 1:
+                combined_ds = read_tasks[0]
+            else:
+                combined_ds = read_tasks[0].union(*read_tasks[1:])
+
+            combined_ds = combined_ds.map(
+                lambda record: {
+                    **record,
+                    "_trace_id": compute_mm_hash(record, prefix="doc-"),
+                }
+            )
+
+            if read_nums is not None:
+                combined_ds = combined_ds.limit(read_nums)
+
+            # sample record
+            for i, item in enumerate(combined_ds.take(1)):
+                logger.debug("[READ] Sample record %d: %s", i, item)
+
+            logger.info("[READ] Successfully read files from %s", input_path)
+            return combined_ds
 
     except Exception as e:
         logger.error("[READ] Failed to read files from %s: %s", input_path, e)
```

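For callers the visible change is the new kv_backend parameter, which read() forwards to init_storage under the "read" namespace. A hedged invocation sketch (the import path and argument values are assumptions; "rocksdb" is the default shown in the diff):

```python
import ray

from graphgen.operators.read.read import read  # assumed import path

ray.init()

ds = read(
    input_path="data/docs",              # file or directory path(s)
    allowed_suffix=["pdf", "md", "txt"],
    working_dir="cache",
    kv_backend="rocksdb",                # routed to init_storage(namespace="read")
    parallelism=4,
    recursive=True,
    read_nums=100,                       # cap the number of documents
)
print(ds.count())                        # Ray Dataset of documents
```
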