Skip to content

Commit 7d729ab

Browse files
Author: Jin Zhou (committed)
Commit message: support dask expressions
1 parent 2a6d2ca commit 7d729ab

7 files changed

Lines changed: 134 additions & 89 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
adaptor_test.py
2+
__pycache__/

taskvine/src/graph/dagvine/blueprint_graph/adaptor.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# where the private module is unavailable or type-checkers can't resolve it).
2020
dts = None
2121

22-
from ndcctools.taskvine.dagvine.blueprint_graph.blueprint_graph import TaskOutputRef
22+
from ndcctools.taskvine.dagvine.blueprint_graph.blueprint_graph import TaskOutputRef, BlueprintGraph
2323

2424

2525
def _identity(value):
@@ -32,8 +32,11 @@ class Adaptor:
3232

3333
_LEAF_TYPES = (str, bytes, bytearray, memoryview, int, float, bool, type(None))
3434

35-
def __init__(self, collection_dict):
36-
self.original_collection_dict = collection_dict
35+
def __init__(self, task_dict):
36+
37+
if isinstance(task_dict, BlueprintGraph):
38+
self.converted = task_dict
39+
return
3740

3841
# TaskSpec-only state used to "lift" inline Tasks that cannot be reduced to
3942
# a pure Python value (or would be unsafe/expensive to inline).
@@ -45,23 +48,23 @@ def __init__(self, collection_dict):
4548
# lifted keys remain visible across subsequent conversions/dedup/reference checks.
4649
self._task_keys = set()
4750

48-
normalized = self._normalize_task_dict(collection_dict)
49-
self.task_dict = self._convert_to_blueprint_tasks(normalized)
51+
normalized = self._normalize_task_dict(task_dict)
52+
self.converted = self._convert_to_blueprint_tasks(normalized)
5053

51-
def _normalize_task_dict(self, collection_dict):
54+
def _normalize_task_dict(self, task_dict):
5255
"""Collapse every supported input style into a classic `{key: sexpr or TaskSpec}` mapping."""
5356
from_dask_collection = bool(
54-
is_dask_collection and any(is_dask_collection(v) for v in collection_dict.values())
57+
is_dask_collection and any(is_dask_collection(v) for v in task_dict.values())
5558
)
5659

5760
if from_dask_collection:
58-
task_dict = self._dask_collections_to_task_dict(collection_dict)
61+
task_dict = self._dask_collections_to_task_dict(task_dict)
5962
else:
6063
# IMPORTANT: treat plain user dicts as DAGVine sexprs by default.
6164
# If we unconditionally run `dask._task_spec.convert_legacy_graph(...)` when
6265
# dts is available, Dask will interpret our "final Mapping is kwargs"
6366
# convention as a positional dict argument, breaking sexpr semantics.
64-
task_dict = dict(collection_dict)
67+
task_dict = dict(task_dict)
6568

6669
# Only ask Dask to rewrite legacy graphs when we *know* the input came
6770
# from a Dask collection/HLG. This keeps classic DAGVine sexprs stable
@@ -227,36 +230,38 @@ def _should_wrap(self, obj, task_keys):
227230
"""Decide whether a value should become a `TaskOutputRef`."""
228231
if isinstance(obj, self._LEAF_TYPES):
229232
if isinstance(obj, str):
230-
return obj in task_keys
233+
hit = obj in task_keys
234+
return hit
231235
return False
232236
try:
233-
return obj in task_keys
237+
hit = obj in task_keys
238+
return hit
234239
except TypeError:
235240
return False
236241

237242
# Flatten Dask collections into the dict-of-tasks structure the rest of the
238243
# pipeline expects. DAGVine clients often hand us a dict like
239244
# {"result": dask.delayed(...)}; we merge the underlying HighLevelGraphs so
240245
# `ContextGraph` sees the same dict representation C does.
241-
def _dask_collections_to_task_dict(self, collection_dict):
246+
def _dask_collections_to_task_dict(self, task_dict):
242247
"""Flatten Dask collections into the classic dict-of-task layout."""
243248
assert is_dask_collection is not None
244249
from dask.highlevelgraph import HighLevelGraph, ensure_dict
245250

246-
if not isinstance(collection_dict, dict):
251+
if not isinstance(task_dict, dict):
247252
raise TypeError("Input must be a dict")
248253

249-
for k, v in collection_dict.items():
254+
for k, v in task_dict.items():
250255
if not is_dask_collection(v):
251256
raise TypeError(
252257
f"Input must be a dict of DaskCollection, but found {k} with type {type(v)}"
253258
)
254259

255260
if dts:
256-
sub_hlgs = [v.dask for v in collection_dict.values()]
261+
sub_hlgs = [v.dask for v in task_dict.values()]
257262
hlg = HighLevelGraph.merge(*sub_hlgs).to_dict()
258263
else:
259-
hlg = dask.base.collections_to_dsk(collection_dict.values())
264+
hlg = dask.base.collections_to_dsk(task_dict.values())
260265

261266
return ensure_dict(hlg)
262267

@@ -299,18 +304,21 @@ def _unwrap_dts_operand(self, operand, task_keys, *, parent_key=None):
299304

300305
literal_cls = getattr(dts, "Literal", None)
301306
if literal_cls and isinstance(operand, literal_cls):
302-
return getattr(operand, "value", None)
307+
value = getattr(operand, "value", None)
308+
return value
303309

304310
datanode_cls = getattr(dts, "DataNode", None)
305311
if datanode_cls and isinstance(operand, datanode_cls):
306-
return operand.value
312+
value = operand.value
313+
return value
307314

308315
nested_cls = getattr(dts, "NestedContainer", None)
309316
if nested_cls and isinstance(operand, nested_cls):
310317
payload = getattr(operand, "value", None)
311318
if payload is None:
312319
payload = getattr(operand, "data", None)
313-
return self._unwrap_dts_operand(payload, task_keys, parent_key=parent_key)
320+
value = self._unwrap_dts_operand(payload, task_keys, parent_key=parent_key)
321+
return value
314322

315323
task_cls = getattr(dts, "Task", None)
316324
if task_cls and isinstance(operand, task_cls):
@@ -323,15 +331,31 @@ def _unwrap_dts_operand(self, operand, task_keys, *, parent_key=None):
323331
# Otherwise it is an inline expression. Reduce if safe, else lift.
324332
func = self._extract_callable_from_task(operand)
325333
if func is None:
326-
return self._lift_inline_task(operand, task_keys, parent_key=parent_key)
334+
out = self._lift_inline_task(operand, task_keys, parent_key=parent_key)
335+
return out
327336

328337
# Special-case: Dask internal identity-cast wrappers should not be called
329-
# during adaptation. Reduce structurally by returning the first argument.
338+
# during adaptation. Reduce structurally by unwrapping all args and
339+
# rebuilding the requested container type. This preserves dependency
340+
# edges (critical for WCC) without executing arbitrary code.
330341
if self._is_identity_cast_op(func):
331342
raw_args = getattr(operand, "args", ()) or ()
332-
if not raw_args:
333-
return None
334-
return self._unwrap_dts_operand(raw_args[0], task_keys, parent_key=parent_key)
343+
raw_kwargs = getattr(operand, "kwargs", {}) or {}
344+
typ = raw_kwargs.get("typ", None)
345+
346+
values = [self._unwrap_dts_operand(a, task_keys, parent_key=parent_key) for a in raw_args]
347+
348+
# Only allow safe container constructors here; otherwise lift.
349+
safe_types = (list, tuple, set, frozenset, dict)
350+
if typ in safe_types:
351+
try:
352+
casted = typ(values)
353+
except Exception:
354+
return self._lift_inline_task(operand, task_keys, parent_key=parent_key)
355+
return casted
356+
357+
# Unknown/unsafe typ: lift so the worker executes the real op.
358+
return self._lift_inline_task(operand, task_keys, parent_key=parent_key)
335359

336360
if self._is_pure_value_op(func):
337361
reduced, used_lift = self._reduce_inline_task(operand, task_keys, parent_key=parent_key)

taskvine/src/graph/dagvine/blueprint_graph/adaptor_test.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -501,15 +501,20 @@ def fake_identity_cast(x, *_, **__):
501501
fake_identity_cast.__module__ = "dask._fake"
502502

503503
graph = {
504-
"raw": _FakeDataNode(value=5),
504+
"raw0": _FakeDataNode(value=5),
505+
"raw1": _FakeDataNode(value=6),
505506
"outer": _FakeTask(
506507
key="outer",
507508
function=lambda x: x,
508509
args=(
509510
_FakeTask(
510511
key=None,
511512
function=fake_identity_cast,
512-
args=(_FakeTaskRef("raw"),),
513+
args=(
514+
_FakeTaskRef("raw0"),
515+
_FakeTaskRef("raw1"),
516+
),
517+
kwargs={"typ": list},
513518
),
514519
),
515520
),
@@ -519,8 +524,12 @@ def fake_identity_cast(x, *_, **__):
519524
_, outer_args, outer_kwargs = adapted["outer"]
520525
self.assertEqual(outer_kwargs, {})
521526
self.assertEqual(len(outer_args), 1)
522-
self.assertIsInstance(outer_args[0], TaskOutputRef)
523-
self.assertEqual(outer_args[0].task_key, "raw")
527+
self.assertIsInstance(outer_args[0], list)
528+
self.assertEqual(len(outer_args[0]), 2)
529+
self.assertIsInstance(outer_args[0][0], TaskOutputRef)
530+
self.assertIsInstance(outer_args[0][1], TaskOutputRef)
531+
self.assertEqual(outer_args[0][0].task_key, "raw0")
532+
self.assertEqual(outer_args[0][1].task_key, "raw1")
524533
self.assertFalse(any(str(k).startswith("__lift__") for k in adapted.keys()))
525534

526535

taskvine/src/graph/dagvine/blueprint_graph/blueprint_graph.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import cloudpickle
99

1010

11+
# Lightweight wrapper around task results that optionally pads the payload. The
12+
# padding lets tests model large outputs without altering the logical result.
1113
class TaskOutputWrapper:
1214
def __init__(self, result, extra_size_mb=None):
1315
self.result = result
@@ -24,6 +26,7 @@ def load_from_path(path):
2426
raise FileNotFoundError(f"Task result file not found at {path}")
2527

2628

29+
# A reference to a task output. This is used to represent the output of a task as a dependency of another task.
2730
class TaskOutputRef:
2831
__slots__ = ("task_key", "path")
2932

@@ -37,6 +40,8 @@ def __getitem__(self, key):
3740
return TaskOutputRef(self.task_key, self.path + (key,))
3841

3942

43+
# The BlueprintGraph is a directed acyclic graph (DAG) that represents the logical dependencies between tasks.
44+
# It is used to build the C vine graph.
4045
class BlueprintGraph:
4146

4247
_LEAF_TYPES = (str, bytes, bytearray, memoryview, int, float, bool, type(None))
@@ -55,6 +60,9 @@ def __init__(self):
5560
self.pykey2cid = {} # py_key -> c_id
5661
self.cid2pykey = {} # c_id -> py_key
5762

63+
self.extra_task_output_size_mb = {} # task_key -> extra size in MB
64+
self.extra_task_sleep_time = {} # task_key -> extra sleep time in seconds
65+
5866
def _visit_task_output_refs(self, obj, on_ref, *, rewrite: bool):
5967
seen = set()
6068

@@ -154,7 +162,7 @@ def task_consumes(self, task_key, *filenames):
154162

155163
def save_task_output(self, task_key, output):
156164
with open(self.outfile_remote_name[task_key], "wb") as f:
157-
wrapped_output = TaskOutputWrapper(output, extra_size_mb=0)
165+
wrapped_output = TaskOutputWrapper(output, extra_size_mb=self.extra_task_output_size_mb[task_key])
158166
cloudpickle.dump(wrapped_output, f)
159167

160168
def load_task_output(self, task_key):
@@ -191,6 +199,7 @@ def verify_topo(g, topo):
191199
print("topo verified: ok")
192200

193201
def finalize(self):
202+
# build the dependencies determined by files produced and consumed
194203
for file, producer in self.producer_of.items():
195204
for consumer in self.consumers_of.get(file, ()):
196205
self.parents_of[consumer].add(producer)

taskvine/src/graph/dagvine/blueprint_graph/proxy_functions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
from ndcctools.taskvine.utils import load_variable_from_library
7+
import time
78

89

910
def compute_task(bg, task_expr):
@@ -33,10 +34,6 @@ def on_ref(r):
3334
r_args = bg._visit_task_output_refs(args, on_ref, rewrite=True)
3435
r_kwargs = bg._visit_task_output_refs(kwargs, on_ref, rewrite=True)
3536

36-
print(f"func: {func}")
37-
print(f"r_args: {r_args}")
38-
print(f"r_kwargs: {r_kwargs}")
39-
4037
return func(*r_args, **r_kwargs)
4138

4239

@@ -48,4 +45,6 @@ def compute_single_key(vine_key):
4845

4946
output = compute_task(bg, task_expr)
5047

48+
time.sleep(bg.extra_task_sleep_time[task_key])
49+
5150
bg.save_task_output(task_key, output)

Commit comments: 0