1313 # L2: one NPU chip
1414 w = Worker(level=2, device_id=8, platform="a2a3", runtime="tensormap_and_ringbuffer")
1515 w.init()
16- w.run(chip_callable, chip_args, block_dim=24 )
16+ w.run(chip_callable, chip_args, config )
1717 w.close()
1818
1919 # L3: multiple chips + SubWorkers, auto-discovery in init()
2020 w = Worker(level=3, device_ids=[8, 9], num_sub_workers=2,
2121 platform="a2a3", runtime="tensormap_and_ringbuffer")
22- cid = w.register(lambda: postprocess())
22+ cid = w.register(lambda args : postprocess())
2323 w.init()
2424
25- def my_orch(orch, args):
26- r = orch.submit_next_level(chip_callable, chip_args_ptr, config, outputs=[64] )
27- orch.submit_sub(cid, inputs=[r.outputs[0].ptr] )
25+ def my_orch(orch, args, cfg ):
26+ r = orch.submit_next_level(chip_callable, chip_args_ptr, cfg )
27+ orch.submit_sub(cid, sub_args )
2828
29- w.run(Task(orch= my_orch, args= my_args) )
29+ w.run(my_orch, my_args, my_config )
3030 w.close()
3131"""
3232
3333import ctypes
3434import os
3535import struct
3636import sys
37- from dataclasses import dataclass , field
3837from multiprocessing .shared_memory import SharedMemory
3938from typing import Any , Callable , Optional
4039
4140from .orchestrator import Orchestrator
4241from .task_interface import (
4342 DIST_MAILBOX_SIZE ,
43+ ChipCallConfig ,
4444 ChipWorker ,
45+ ContinuousTensor ,
46+ DataType ,
4547 DistWorker ,
48+ TaskArgs ,
4649 _ChipWorker ,
4750)
4851
49- # ---------------------------------------------------------------------------
50- # Task
51- # ---------------------------------------------------------------------------
52-
53-
54- @dataclass
55- class Task :
56- """Execution unit for Worker.run() at any level.
57-
58- For L2: call ``Worker.run(chip_callable, chip_args, config)`` directly.
59- For L3+: provide an orch function ``fn(orchestrator, args)`` that builds
60- the DAG via ``orchestrator.submit_*``.
61- """
62-
63- orch : Callable
64- args : Any = field (default = None )
65-
66-
6752# ---------------------------------------------------------------------------
6853# Unified mailbox layout (must match dist_worker_manager.h MAILBOX_OFF_*)
6954# ---------------------------------------------------------------------------
7055#
7156# One layout for both NEXT_LEVEL (chip) and SUB workers. SUB children
72- # read `callable` as a uint64 encoding the callable_id and ignore the
73- # config + args_blob region.
57+ # read `callable` as a uint64 encoding the callable_id and decode the
58+ # args_blob region to pass TaskArgs to the registered callable .
7459
# Field byte-offsets into the unified shared-memory mailbox; these must
# stay in sync with the MAILBOX_OFF_* constants in dist_worker_manager.h.
_OFF_STATE = 0  # state word polled by worker loops (exact semantics defined in dist_worker_manager.h — confirm there)
_OFF_ERROR = 4  # int32 error code written via pack_into: 0 = ok, nonzero = failure in _sub_worker_loop (2 = callable raised)
@@ -93,6 +78,35 @@ def _mailbox_addr(shm: SharedMemory) -> int:
9378 return ctypes .addressof (ctypes .c_char .from_buffer (buf ))
9479
9580
def _read_args_from_mailbox(buf) -> TaskArgs:
    """Decode the TaskArgs blob written by C++ write_blob from the mailbox.

    Blob layout at _OFF_ARGS:
        int32 tensor_count (T), int32 scalar_count (S),
        ContinuousTensor[T] (40 B each), uint64_t[S] (8 B each).
    """
    # Header: two little-endian int32 counts, back to back.
    tensor_count, scalar_count = struct.unpack_from("2i", buf, _OFF_ARGS)

    args = TaskArgs()

    # Tensor records start right after the 8-byte header.
    tensors_base = _OFF_ARGS + 8
    for idx in range(tensor_count):
        rec = tensors_base + idx * 40
        # Per-record layout: uint64 data pointer, uint32 shape[5],
        # uint32 ndims, dtype byte at offset 32 (rest of the 40 B is
        # padding on the C++ side).
        (data_ptr,) = struct.unpack_from("Q", buf, rec)
        dims = struct.unpack_from("5I", buf, rec + 8)
        (rank,) = struct.unpack_from("I", buf, rec + 28)
        (dtype_raw,) = struct.unpack_from("B", buf, rec + 32)
        tensor = ContinuousTensor.make(
            data_ptr, tuple(dims[:rank]), DataType(dtype_raw)
        )
        args.add_tensor(tensor)

    # Scalars follow the tensor array as plain uint64 values.
    scalars_base = tensors_base + tensor_count * 40
    for idx in range(scalar_count):
        (value,) = struct.unpack_from("Q", buf, scalars_base + idx * 8)
        args.add_scalar(value)

    return args
108+
109+
96110def _sub_worker_loop (buf , registry : dict ) -> None :
97111 """Runs in forked child process. Reads unified mailbox layout."""
98112 while True :
@@ -105,7 +119,8 @@ def _sub_worker_loop(buf, registry: dict) -> None:
105119 error = 1
106120 else :
107121 try :
108- fn ()
122+ args = _read_args_from_mailbox (buf )
123+ fn (args )
109124 except Exception : # noqa: BLE001
110125 error = 2
111126 struct .pack_into ("i" , buf , _OFF_ERROR , error )
@@ -351,25 +366,26 @@ def _start_level3(self) -> None:
351366 # run — uniform entry point
352367 # ------------------------------------------------------------------
353368
354- def run (self , task_or_callable , args = None , ** kwargs ) -> None :
355- """Execute one task synchronously.
369+ def run (self , callable , args = None , config = None ) -> None :
370+ """Execute one task (L2) or one DAG (L3+) synchronously.
356371
357- L2: run(chip_callable, chip_args, block_dim=N)
358- L3: run(Task(orch=fn, args=...))
372+ callable: ChipCallable (L2) or Python orch fn (L3+)
373+ args: TaskArgs (optional)
374+ config: ChipCallConfig (optional, default-constructed if None)
359375 """
360376 assert self ._initialized , "Worker not initialized; call init() first"
377+ cfg = config if config is not None else ChipCallConfig ()
361378
362379 if self .level == 2 :
363380 assert self ._chip_worker is not None
364- self ._chip_worker .run (task_or_callable , args , ** kwargs )
381+ self ._chip_worker .run (callable , args , cfg )
365382 else :
366383 self ._start_level3 ()
367384 assert self ._orch is not None
368385 assert self ._dist_worker is not None
369- task = task_or_callable
370386 self ._orch ._scope_begin ()
371387 try :
372- task . orch (self ._orch , task . args )
388+ callable (self ._orch , args , cfg )
373389 finally :
374390 # Always release scope refs and drain so ring slots aren't
375391 # stranded when the orch fn raises mid-DAG.
0 commit comments