1919 # where the private module is unavailable or type-checkers can't resolve it).
2020 dts = None
2121
22- from ndcctools .taskvine .dagvine .blueprint_graph .blueprint_graph import TaskOutputRef
22+ from ndcctools .taskvine .dagvine .blueprint_graph .blueprint_graph import TaskOutputRef , BlueprintGraph
2323
2424
2525def _identity (value ):
@@ -32,8 +32,11 @@ class Adaptor:
3232
3333 _LEAF_TYPES = (str , bytes , bytearray , memoryview , int , float , bool , type (None ))
3434
35- def __init__ (self , collection_dict ):
36- self .original_collection_dict = collection_dict
35+ def __init__ (self , task_dict ):
36+
37+ if isinstance (task_dict , BlueprintGraph ):
38+ self .converted = task_dict
39+ return
3740
3841 # TaskSpec-only state used to "lift" inline Tasks that cannot be reduced to
3942 # a pure Python value (or would be unsafe/expensive to inline).
@@ -45,23 +48,23 @@ def __init__(self, collection_dict):
4548 # lifted keys remain visible across subsequent conversions/dedup/reference checks.
4649 self ._task_keys = set ()
4750
48- normalized = self ._normalize_task_dict (collection_dict )
49- self .task_dict = self ._convert_to_blueprint_tasks (normalized )
51+ normalized = self ._normalize_task_dict (task_dict )
52+ self .converted = self ._convert_to_blueprint_tasks (normalized )
5053
51- def _normalize_task_dict (self , collection_dict ):
54+ def _normalize_task_dict (self , task_dict ):
5255 """Collapse every supported input style into a classic `{key: sexpr or TaskSpec}` mapping."""
5356 from_dask_collection = bool (
54- is_dask_collection and any (is_dask_collection (v ) for v in collection_dict .values ())
57+ is_dask_collection and any (is_dask_collection (v ) for v in task_dict .values ())
5558 )
5659
5760 if from_dask_collection :
58- task_dict = self ._dask_collections_to_task_dict (collection_dict )
61+ task_dict = self ._dask_collections_to_task_dict (task_dict )
5962 else :
6063 # IMPORTANT: treat plain user dicts as DAGVine sexprs by default.
6164 # If we unconditionally run `dask._task_spec.convert_legacy_graph(...)` when
6265 # dts is available, Dask will interpret our "final Mapping is kwargs"
6366 # convention as a positional dict argument, breaking sexpr semantics.
64- task_dict = dict (collection_dict )
67+ task_dict = dict (task_dict )
6568
6669 # Only ask Dask to rewrite legacy graphs when we *know* the input came
6770 # from a Dask collection/HLG. This keeps classic DAGVine sexprs stable
def _should_wrap(self, obj, task_keys):
    """Decide whether a value should become a `TaskOutputRef`.

    Only strings and hashable non-leaf objects can name another task; leaf
    scalars (and unhashable values) never refer to a task output.
    """
    if isinstance(obj, self._LEAF_TYPES):
        # Among leaf types, only a string can be a task key.
        return obj in task_keys if isinstance(obj, str) else False
    # Non-leaf values may be unhashable, in which case membership testing
    # raises TypeError — treat that as "not a task reference".
    try:
        return obj in task_keys
    except TypeError:
        return False
236241
# Flatten Dask collections into the dict-of-tasks structure the rest of the
# pipeline expects. DAGVine clients often hand us a dict like
# {"result": dask.delayed(...)}; we merge the underlying HighLevelGraphs so
# `ContextGraph` sees the same dict representation C does.
def _dask_collections_to_task_dict(self, task_dict):
    """Flatten Dask collections into the classic dict-of-task layout."""
    # This path is only reachable when Dask was importable at module load.
    assert is_dask_collection is not None
    from dask.highlevelgraph import HighLevelGraph, ensure_dict

    if not isinstance(task_dict, dict):
        raise TypeError("Input must be a dict")

    # Every value must itself be a Dask collection; fail fast on the first
    # offender so the caller sees which key was malformed.
    for k, v in task_dict.items():
        if is_dask_collection(v):
            continue
        raise TypeError(
            f"Input must be a dict of DaskCollection, but found {k} with type {type(v)}"
        )

    # When the modern task-spec module is present, merge the per-collection
    # HighLevelGraphs directly; otherwise fall back to the legacy flattener.
    if dts:
        merged = HighLevelGraph.merge(*[c.dask for c in task_dict.values()])
        graph = merged.to_dict()
    else:
        graph = dask.base.collections_to_dsk(task_dict.values())

    return ensure_dict(graph)
262267
@@ -299,18 +304,21 @@ def _unwrap_dts_operand(self, operand, task_keys, *, parent_key=None):
299304
300305 literal_cls = getattr (dts , "Literal" , None )
301306 if literal_cls and isinstance (operand , literal_cls ):
302- return getattr (operand , "value" , None )
307+ value = getattr (operand , "value" , None )
308+ return value
303309
304310 datanode_cls = getattr (dts , "DataNode" , None )
305311 if datanode_cls and isinstance (operand , datanode_cls ):
306- return operand .value
312+ value = operand .value
313+ return value
307314
308315 nested_cls = getattr (dts , "NestedContainer" , None )
309316 if nested_cls and isinstance (operand , nested_cls ):
310317 payload = getattr (operand , "value" , None )
311318 if payload is None :
312319 payload = getattr (operand , "data" , None )
313- return self ._unwrap_dts_operand (payload , task_keys , parent_key = parent_key )
320+ value = self ._unwrap_dts_operand (payload , task_keys , parent_key = parent_key )
321+ return value
314322
315323 task_cls = getattr (dts , "Task" , None )
316324 if task_cls and isinstance (operand , task_cls ):
@@ -323,15 +331,31 @@ def _unwrap_dts_operand(self, operand, task_keys, *, parent_key=None):
323331 # Otherwise it is an inline expression. Reduce if safe, else lift.
324332 func = self ._extract_callable_from_task (operand )
325333 if func is None :
326- return self ._lift_inline_task (operand , task_keys , parent_key = parent_key )
334+ out = self ._lift_inline_task (operand , task_keys , parent_key = parent_key )
335+ return out
327336
328337 # Special-case: Dask internal identity-cast wrappers should not be called
329- # during adaptation. Reduce structurally by returning the first argument.
338+ # during adaptation. Reduce structurally by unwrapping all args and
339+ # rebuilding the requested container type. This preserves dependency
340+ # edges (critical for WCC) without executing arbitrary code.
330341 if self ._is_identity_cast_op (func ):
331342 raw_args = getattr (operand , "args" , ()) or ()
332- if not raw_args :
333- return None
334- return self ._unwrap_dts_operand (raw_args [0 ], task_keys , parent_key = parent_key )
343+ raw_kwargs = getattr (operand , "kwargs" , {}) or {}
344+ typ = raw_kwargs .get ("typ" , None )
345+
346+ values = [self ._unwrap_dts_operand (a , task_keys , parent_key = parent_key ) for a in raw_args ]
347+
348+ # Only allow safe container constructors here; otherwise lift.
349+ safe_types = (list , tuple , set , frozenset , dict )
350+ if typ in safe_types :
351+ try :
352+ casted = typ (values )
353+ except Exception :
354+ return self ._lift_inline_task (operand , task_keys , parent_key = parent_key )
355+ return casted
356+
357+ # Unknown/unsafe typ: lift so the worker executes the real op.
358+ return self ._lift_inline_task (operand , task_keys , parent_key = parent_key )
335359
336360 if self ._is_pure_value_op (func ):
337361 reduced , used_lift = self ._reduce_inline_task (operand , task_keys , parent_key = parent_key )
0 commit comments