|
8 | 8 |
|
9 | 9 |
|
10 | 10 | class BaseOperator(ABC): |
11 | | - def __init__(self, working_dir: str = "cache", op_name: str = None): |
| 11 | + def __init__( |
| 12 | + self, |
| 13 | + working_dir: str = "cache", |
| 14 | + kv_backend: str = "rocksdb", |
| 15 | + op_name: str = None, |
| 16 | + ): |
12 | 17 | # lazy import to avoid circular import |
| 18 | + from graphgen.common import init_storage |
13 | 19 | from graphgen.utils import set_logger |
14 | 20 |
|
15 | 21 | log_dir = os.path.join(working_dir, "logs") |
16 | 22 | self.op_name = op_name or self.__class__.__name__ |
| 23 | + self.working_dir = working_dir |
| 24 | + self.kv_storage = init_storage( |
| 25 | + backend=kv_backend, working_dir=working_dir, namespace=self.op_name |
| 26 | + ) |
17 | 27 |
|
18 | 28 | try: |
19 | 29 | ctx = ray.get_runtime_context() |
@@ -45,17 +55,80 @@ def __call__( |
45 | 55 |
|
46 | 56 | logger_token = CURRENT_LOGGER_VAR.set(self.logger) |
47 | 57 | try: |
48 | | - result = self.process(batch) |
| 58 | + self.kv_storage.reload() |
| 59 | + to_process, recovered = self.split(batch) |
| 60 | + # yield recovered chunks first |
| 61 | + if not recovered.empty: |
| 62 | + yield recovered |
| 63 | + |
| 64 | + if to_process.empty: |
| 65 | + return |
| 66 | + |
| 67 | + docs = to_process.to_dict(orient="records") |
| 68 | + result = self.process(docs) |
49 | 69 | if inspect.isgenerator(result): |
50 | 70 | yield from result |
51 | 71 | else: |
52 | 72 | yield result |
53 | 73 | finally: |
54 | 74 | CURRENT_LOGGER_VAR.reset(logger_token) |
55 | 75 |
|
56 | | - @abstractmethod |
57 | | - def process(self, batch): |
58 | | - raise NotImplementedError("Subclasses must implement the process method.") |
59 | | - |
60 | 76 | def get_logger(self): |
61 | 77 | return self.logger |
| 78 | + |
| 79 | + def get_meta_forward(self): |
| 80 | + return self.kv_storage.get_by_id("_meta_forward") or {} |
| 81 | + |
| 82 | + def get_meta_inverse(self): |
| 83 | + return self.kv_storage.get_by_id("_meta_inverse") or {} |
| 84 | + |
| 85 | + def get_trace_id(self, content: dict) -> str: |
| 86 | + from graphgen.utils import compute_dict_hash |
| 87 | + |
| 88 | + return compute_dict_hash(content, prefix=f"{self.op_name}-") |
| 89 | + |
| 90 | + def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: |
| 91 | + """ |
| 92 | + Split the input batch into to_process & processed based on _meta data in KV_storage |
| 93 | + :param batch |
| 94 | + :return: |
| 95 | + to_process: DataFrame of documents to be chunked |
| 96 | + recovered: Result DataFrame of already chunked documents |
| 97 | + """ |
| 98 | + meta_forward = self.get_meta_forward() |
| 99 | + meta_ids = set(meta_forward.keys()) |
| 100 | + mask = batch["_trace_id"].isin(meta_ids) |
| 101 | + to_process = batch[~mask] |
| 102 | + processed = batch[mask] |
| 103 | + |
| 104 | + if processed.empty: |
| 105 | + return to_process, pd.DataFrame() |
| 106 | + |
| 107 | + all_ids = [ |
| 108 | + pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, []) |
| 109 | + ] |
| 110 | + |
| 111 | + recovered_chunks = self.kv_storage.get_by_ids(all_ids) |
| 112 | + recovered_chunks = [c for c in recovered_chunks if c is not None] |
| 113 | + return to_process, pd.DataFrame(recovered_chunks) |
| 114 | + |
| 115 | + def store(self, results: list, meta_update: dict): |
| 116 | + batch = {res["_trace_id"]: res for res in results} |
| 117 | + self.kv_storage.upsert(batch) |
| 118 | + |
| 119 | + # update forward meta |
| 120 | + forward_meta = self.get_meta_forward() |
| 121 | + forward_meta.update(meta_update) |
| 122 | + self.kv_storage.update({"_meta_forward": forward_meta}) |
| 123 | + |
| 124 | + # update inverse meta |
| 125 | + inverse_meta = self.get_meta_inverse() |
| 126 | + for k, v_list in meta_update.items(): |
| 127 | + for v in v_list: |
| 128 | + inverse_meta[v] = k |
| 129 | + self.kv_storage.update({"_meta_inverse": inverse_meta}) |
| 130 | + self.kv_storage.index_done_callback() |
| 131 | + |
    @abstractmethod
    def process(self, batch: list) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
        """Transform a batch of input records into result DataFrame(s).

        :param batch: list of record dicts to process (in __call__ this is the
            ``to_process`` rows converted via ``to_dict(orient="records")``)
        :return: a single result DataFrame, or a generator/iterable of
            DataFrames when results are produced incrementally
        """
        pass
0 commit comments