|
2 | 2 | import logging |
3 | 3 | import os |
4 | 4 | from collections import defaultdict, deque |
5 | | -from functools import wraps |
6 | 5 | from typing import Any, Callable, Dict, List, Set |
7 | 6 |
|
8 | 7 | import ray |
@@ -103,7 +102,6 @@ def _scan_storage_requirements(self) -> tuple[set[str], set[str]]: |
103 | 102 | kv_namespaces = set() |
104 | 103 | graph_namespaces = set() |
105 | 104 |
|
106 | | - # TODO: Temporarily hard-coded; node storage will be centrally managed later. |
107 | 105 | for node in self.config.nodes: |
108 | 106 | op_name = node.op_name |
109 | 107 | if self._function_needs_param(op_name, "kv_backend"): |
@@ -232,62 +230,38 @@ def _filter_kwargs( |
232 | 230 |
|
233 | 231 | input_ds = self._get_input_dataset(node, initial_ds) |
234 | 232 |
|
235 | | - if inspect.isclass(op_handler): |
236 | | - execution_params = node.execution_params or {} |
237 | | - replicas = execution_params.get("replicas", 1) |
238 | | - batch_size = ( |
239 | | - int(execution_params.get("batch_size")) |
240 | | - if "batch_size" in execution_params |
241 | | - else "default" |
| 233 | + # if inspect.isclass(op_handler): |
| 234 | + execution_params = node.execution_params or {} |
| 235 | + replicas = execution_params.get("replicas", 1) |
| 236 | + batch_size = ( |
| 237 | + int(execution_params.get("batch_size")) |
| 238 | + if "batch_size" in execution_params |
| 239 | + else "default" |
| 240 | + ) |
| 241 | + compute_resources = execution_params.get("compute_resources", {}) |
| 242 | + |
| 243 | + if node.type == "aggregate": |
| 244 | + self.datasets[node.id] = input_ds.repartition(1).map_batches( |
| 245 | + op_handler, |
| 246 | + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1), |
| 247 | + batch_size=None, # aggregate processes the whole dataset at once |
| 248 | + num_gpus=compute_resources.get("num_gpus", 0) |
| 249 | + if compute_resources |
| 250 | + else 0, |
| 251 | + fn_constructor_kwargs=node_params, |
| 252 | + batch_format="pandas", |
242 | 253 | ) |
243 | | - compute_resources = execution_params.get("compute_resources", {}) |
244 | | - |
245 | | - if node.type == "aggregate": |
246 | | - self.datasets[node.id] = input_ds.repartition(1).map_batches( |
247 | | - op_handler, |
248 | | - compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1), |
249 | | - batch_size=None, # aggregate processes the whole dataset at once |
250 | | - num_gpus=compute_resources.get("num_gpus", 0) |
251 | | - if compute_resources |
252 | | - else 0, |
253 | | - fn_constructor_kwargs=node_params, |
254 | | - batch_format="pandas", |
255 | | - ) |
256 | | - else: |
257 | | - # others like map, filter, flatmap, map_batch let actors process data inside batches |
258 | | - self.datasets[node.id] = input_ds.map_batches( |
259 | | - op_handler, |
260 | | - compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas), |
261 | | - batch_size=batch_size, |
262 | | - num_gpus=compute_resources.get("num_gpus", 0) |
263 | | - if compute_resources |
264 | | - else 0, |
265 | | - fn_constructor_kwargs=node_params, |
266 | | - batch_format="pandas", |
267 | | - ) |
268 | | - |
269 | 254 | else: |
270 | | - |
271 | | - @wraps(op_handler) |
272 | | - def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: |
273 | | - return op_handler(row_or_batch, **node_params) |
274 | | - |
275 | | - if node.type == "map": |
276 | | - self.datasets[node.id] = input_ds.map(func_wrapper) |
277 | | - elif node.type == "filter": |
278 | | - self.datasets[node.id] = input_ds.filter(func_wrapper) |
279 | | - elif node.type == "flatmap": |
280 | | - self.datasets[node.id] = input_ds.flat_map(func_wrapper) |
281 | | - elif node.type == "aggregate": |
282 | | - self.datasets[node.id] = input_ds.repartition(1).map_batches( |
283 | | - func_wrapper, batch_format="default" |
284 | | - ) |
285 | | - elif node.type == "map_batch": |
286 | | - self.datasets[node.id] = input_ds.map_batches(func_wrapper) |
287 | | - else: |
288 | | - raise ValueError( |
289 | | - f"Unsupported node type {node.type} for node {node.id}" |
290 | | - ) |
| 255 | + self.datasets[node.id] = input_ds.map_batches( |
| 256 | + op_handler, |
| 257 | + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas), |
| 258 | + batch_size=batch_size, |
| 259 | + num_gpus=compute_resources.get("num_gpus", 0) |
| 260 | + if compute_resources |
| 261 | + else 0, |
| 262 | + fn_constructor_kwargs=node_params, |
| 263 | + batch_format="pandas", |
| 264 | + ) |
291 | 265 |
|
292 | 266 | def execute( |
293 | 267 | self, initial_ds: ray.data.Dataset, output_dir: str |
@@ -315,6 +289,14 @@ def execute( |
315 | 289 | logger.info("Node %s output saved to %s", node.id, node_output_path) |
316 | 290 |
|
317 | 291 | # Ray reads the dataset lazily |
318 | | - self.datasets[node.id] = ray.data.read_json(node_output_path) |
| 292 | + if os.path.exists(node_output_path) and os.listdir(node_output_path): |
| 293 | + self.datasets[node.id] = ray.data.read_json(node_output_path) |
| 294 | + else: |
| 295 | + self.datasets[node.id] = ray.data.from_items([]) |
| 296 | + logger.warning( |
| 297 | + "Node %s output path %s is empty. Created an empty dataset.", |
| 298 | + node.id, |
| 299 | + node_output_path, |
| 300 | + ) |
319 | 301 |
|
320 | 302 | return self.datasets |
0 commit comments