Commit 492c89f

feat: add read_storage
1 parent 2166fc2 commit 492c89f

9 files changed

Lines changed: 30 additions & 38 deletions


graphgen/models/reader/csv_reader.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def read(self, input_path: Union[str, List[str]]) -> Dataset:
         :return: Ray Dataset containing validated and filtered data.
         """

-        ds = ray.data.read_csv(input_path)
+        ds = ray.data.read_csv(input_path, include_paths=True)
         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
         ds = ds.filter(self._should_keep_item)
         return ds
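
The include_paths=True flag is what threads file provenance through every reader touched by this commit. A minimal sketch of its effect, with a hypothetical file name: Ray Data adds a "path" column carrying each row's source file, next to the CSV's own columns.

    import ray

    ds = ray.data.read_csv("docs/sample.csv", include_paths=True)
    # Each row now carries a "path" column in addition to the CSV columns,
    # e.g. {'text': '...', 'path': 'docs/sample.csv'}
    print(ds.take(1))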

graphgen/models/reader/json_reader.py

Lines changed: 4 additions & 1 deletion

@@ -34,10 +34,13 @@ def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
                 with open(file, "r", encoding="utf-8") as f:
                     data = json.load(f)
                 data = self._unify_schema(data)
+                # add path
+                for item in data:
+                    item["path"] = file
                 file_ds: ray.data.Dataset = ray.data.from_items(data)
                 ds = ds.union(file_ds)  # type: ignore
         else:
-            ds = ray.data.read_json(input_path)
+            ds = ray.data.read_json(input_path, include_paths=True)
         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
         ds = ds.filter(self._should_keep_item)
         return ds
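
ray.data.from_items builds a dataset from records already in memory, so there is no include_paths equivalent on that branch; the added loop attaches the path by hand. A small sketch of the same pattern (records and file name hypothetical):

    import ray

    data = [{"text": "hello"}, {"text": "world"}]
    for item in data:
        item["path"] = "docs/sample.json"  # attach provenance manually
    ds = ray.data.from_items(data)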

graphgen/models/reader/parquet_reader.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ def read(self, input_path: Union[str, List[str]]) -> Dataset:
         if not ray.is_initialized():
             ray.init()

-        ds = ray.data.read_parquet(input_path)
+        ds = ray.data.read_parquet(input_path, include_paths=True)
         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
         ds = ds.filter(self._should_keep_item)
         return ds

graphgen/models/reader/rdf_reader.py

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]:
                 "id": str(subj),
                 self.text_column: text,
                 "properties": props,
-                "source_file": str(file_path),
+                "path": str(file_path),
             }
             docs.append(doc)

graphgen/models/reader/txt_reader.py

Lines changed: 2 additions & 1 deletion

@@ -18,13 +18,14 @@ def read(
         """
         docs_ds = ray.data.read_binary_files(
             input_path,
-            include_paths=False,
+            include_paths=True,
         )

         docs_ds = docs_ds.map(
             lambda row: {
                 "type": "text",
                 self.text_column: row["bytes"].decode("utf-8"),
+                "path": row["path"],
             }
         )

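
read_binary_files yields rows with a "bytes" column, plus a "path" column once include_paths=True is set; the map above decodes the bytes and forwards the path. A quick check (file name hypothetical):

    import ray

    ds = ray.data.read_binary_files("docs/sample.txt", include_paths=True)
    row = ds.take(1)[0]
    print(row["path"], row["bytes"][:20])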

graphgen/operators/read/parallel_file_scanner.py

Lines changed: 3 additions & 3 deletions

@@ -11,12 +11,12 @@
 class ParallelFileScanner:
     def __init__(
         self,
-        read_cache: BaseKVStorage,
+        input_path_cache: BaseKVStorage,
         allowed_suffix: Optional[List[str]] = None,
         rescan: bool = False,
         max_workers: int = 4,
     ):
-        self.cache = read_cache
+        self.cache = input_path_cache
         self.allowed_suffix = set(allowed_suffix) if allowed_suffix else set()
         self.rescan = rescan
         self.max_workers = max_workers

@@ -61,7 +61,7 @@ def _scan_files(

         # cache check
         cache_key = compute_content_hash(
-            f"scan::{path_str}::recursive::{recursive}", prefix="read-"
+            f"scan::{path_str}::recursive::{recursive}", prefix="path-"
         )
         cached = self.cache.get_by_id(cache_key)
         if cached and not self.rescan:
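
The prefix rename only moves scan results into their own key namespace; the key itself is still an md5 of the scan descriptor. A sketch using compute_content_hash as defined in graphgen/utils/hash.py (the path is hypothetical):

    from hashlib import md5

    def compute_content_hash(content, prefix: str = ""):
        return prefix + md5(content.encode()).hexdigest()

    key = compute_content_hash("scan::/data/docs::recursive::True", prefix="path-")
    print(key)  # 'path-' followed by the 32-hex-digit md5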

graphgen/operators/read/read.py

Lines changed: 17 additions & 10 deletions

@@ -13,7 +13,7 @@
     RDFReader,
     TXTReader,
 )
-from graphgen.utils import compute_mm_hash, logger
+from graphgen.utils import compute_dict_hash, logger

 from .parallel_file_scanner import ParallelFileScanner

@@ -71,15 +71,17 @@ def read(
     :param reader_kwargs: Additional kwargs passed to readers
     :return: Ray Dataset containing all documents
     """
-
-    read_cache = init_storage(
+    input_path_cache = init_storage(
+        backend=kv_backend, working_dir=working_dir, namespace="input_path"
+    )
+    read_storage = init_storage(
         backend=kv_backend, working_dir=working_dir, namespace="read"
     )
     try:
         # 1. Scan all paths to discover files
         logger.info("[READ] Scanning paths: %s", input_path)
         with ParallelFileScanner(
-            read_cache=read_cache,
+            input_path_cache=input_path_cache,
             allowed_suffix=allowed_suffix,
             rescan=False,
             max_workers=parallelism if parallelism > 0 else 1,

@@ -124,12 +126,17 @@ def read(
     if read_nums is not None:
         combined_ds = combined_ds.limit(read_nums)

-    combined_ds = combined_ds.map(
-        lambda record: {
-            **record,
-            "_trace_id": compute_mm_hash(record, prefix="doc-"),
-        }
-    )
+    def add_trace_id(batch):
+        batch["_trace_id"] = batch.apply(
+            lambda row: compute_dict_hash(row, prefix="read-"), axis=1
+        )
+        records = batch.to_dict(orient="records")
+        data_to_upsert = {record["_trace_id"]: record for record in records}
+        read_storage.upsert(data_to_upsert)
+        read_storage.index_done_callback()
+        return batch
+
+    combined_ds = combined_ds.map_batches(add_trace_id, batch_format="pandas")

     # sample record
     for i, item in enumerate(combined_ds.take(1)):
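
This hunk is the heart of the commit: instead of only stamping each record with a trace id, the batch-level hook also upserts every record into the new read storage, keyed by that id. A runnable sketch of the flow, with a plain dict standing in for the BaseKVStorage backend (the dict and sample record are stand-ins, not the project's API; in the real code read_storage is a shared KV backend, so writes made inside Ray workers actually persist):

    import pandas as pd
    import ray

    from graphgen.utils import compute_dict_hash  # import added by this commit

    read_storage = {}  # stand-in for init_storage(..., namespace="read")

    def add_trace_id(batch: pd.DataFrame) -> pd.DataFrame:
        batch["_trace_id"] = batch.apply(
            lambda row: compute_dict_hash(row.to_dict(), prefix="read-"), axis=1
        )
        # persist each record under its trace id so later stages can
        # look documents up by hash
        read_storage.update(
            {rec["_trace_id"]: rec for rec in batch.to_dict(orient="records")}
        )
        return batch

    ds = ray.data.from_items([{"type": "text", "text": "hi", "path": "a.txt"}])
    ds = ds.map_batches(add_trace_id, batch_format="pandas")
    print(ds.take(1)[0]["_trace_id"])  # 'read-...'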

graphgen/utils/__init__.py

Lines changed: 1 addition & 6 deletions

@@ -9,12 +9,7 @@
     split_string_by_multi_markers,
     write_json,
 )
-from .hash import (
-    compute_args_hash,
-    compute_content_hash,
-    compute_dict_hash,
-    compute_mm_hash,
-)
+from .hash import compute_args_hash, compute_content_hash, compute_dict_hash
 from .help_nltk import NLTKHelper
 from .log import CURRENT_LOGGER_VAR, logger, set_logger
 from .loop import create_event_loop

graphgen/utils/hash.py

Lines changed: 0 additions & 14 deletions

@@ -9,20 +9,6 @@ def compute_content_hash(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()


-def compute_mm_hash(item, prefix: str = ""):
-    if item.get("type") == "text" and item.get("text"):
-        content = item["text"].strip()
-    elif item.get("type") == "image" and item.get("img_path"):
-        content = f"image:{item['img_path']}"
-    elif item.get("type") == "table" and item.get("table_body"):
-        content = f"table:{item['table_body']}"
-    elif item.get("type") == "equation" and item.get("text"):
-        content = f"equation:{item['text']}"
-    else:
-        content = str(item)
-    return prefix + md5(content.encode()).hexdigest()
-
-
 def compute_dict_hash(d: dict, prefix: str = ""):
     items = tuple(sorted(d.items()))
     return prefix + md5(str(items).encode()).hexdigest()
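
compute_dict_hash takes over from the removed compute_mm_hash as the trace-id function (see the read.py hunk above). It is key-order independent, which is what makes the id stable for a given record:

    from hashlib import md5

    def compute_dict_hash(d: dict, prefix: str = ""):
        items = tuple(sorted(d.items()))
        return prefix + md5(str(items).encode()).hexdigest()

    a = compute_dict_hash({"text": "hi", "path": "a.txt"}, prefix="read-")
    b = compute_dict_hash({"path": "a.txt", "text": "hi"}, prefix="read-")
    assert a == b  # same record, same id, regardless of key order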
