
Commit 131de2a

Refactor: unify Worker.run(callable, args, config), delete Task (#578)
- Delete Task dataclass; Worker.run now takes (callable, args, config) directly instead of Task(orch=fn, args=...)
- Orch fn signature changes from (o, args) to (o, args, cfg) — receives the ChipCallConfig passed to Worker.run
- Sub callable signature changes from fn() to fn(args) — receives TaskArgs decoded from the mailbox blob via _read_args_from_mailbox
- Add 3 new tests verifying args pass-through: tensor metadata, scalar values, and empty args
- Update scene_test, all L3 unit tests, and L3 scene tests
- Update docs: task-flow.md, distributed_level_runtime.md, worker-manager.md, roadmap.md, orchestrator.py docstring
1 parent b130c22 commit 131de2a

12 files changed

Lines changed: 261 additions & 124 deletions
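
At a glance, the call-site migration this commit makes (a minimal sketch assembled from the commit message and the updated `worker.py` docstring; `chip_callable`, `postprocess`, `sub_args`, `my_args`, and `my_config` are illustrative placeholders):

```python
from simpler.worker import Worker

w = Worker(level=3, device_ids=[8, 9], num_sub_workers=2,
           platform="a2a3", runtime="tensormap_and_ringbuffer")
cid = w.register(lambda args: postprocess(args))  # sub callable is now fn(args), not fn()
w.init()

# Before: w.run(Task(orch=my_orch, args=my_args))   -- Task is deleted
# After:  callable, args, and config are passed to run() directly,
#         and the orch function gains a third parameter, cfg.
def my_orch(o, args, cfg):
    o.submit_next_level(chip_callable, args, cfg)
    o.submit_sub(cid, sub_args)

w.run(my_orch, my_args, my_config)  # config defaults to ChipCallConfig() if omitted
w.close()
```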


docs/distributed_level_runtime.md

Lines changed: 7 additions & 2 deletions
@@ -108,13 +108,18 @@ See [scheduler.md](scheduler.md) for the dispatch loop and coordination.
 The **execution layer**. `WorkerManager` holds two pools of `WorkerThread`s
 (next-level pool and sub pool). Each `WorkerThread`:

-- owns one `IWorker` (`ChipWorker`, `SubWorker`, or nested `Worker`)
+- owns one `IWorker` (`ChipWorker` or nested `Worker`) for next-level workers
 - has its own `std::thread`
 - runs in one of two modes:
   - `THREAD`: calls `worker->run(callable, view, config)` directly in-process
   - `PROCESS`: forks a child at init; each dispatch writes task data to a shm
     mailbox, the child decodes and runs

+SUB workers are Python-only: the forked child process runs a Python loop
+(``_sub_worker_loop``) that reads the args blob from the mailbox, decodes it
+into a ``TaskArgs``, and calls the registered callable as ``fn(args)``.
+There is no C++ ``SubWorker`` class.
+
 See [worker-manager.md](worker-manager.md) for thread/process mechanics, fork
 ordering, and mailbox layout. See [task-flow.md](task-flow.md) for what flows
 through `IWorker::run`.
@@ -177,7 +182,7 @@ w4 = Worker(level=4, child_mode=WorkerManager.Mode.THREAD)
 w4.add_worker(NEXT_LEVEL, w3)  # w3 is an IWorker
 w4.init()

-w4.run(Task(orch=my_l4_orch, task_args=..., config=...))
+w4.run(my_l4_orch, my_args, my_config)
 ```

 When L4's `WorkerThread` dispatches to L3, L3's `Worker::run` opens its own
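
Condensed, the Python-only SUB path described above looks like the following (a sketch mirroring ``_sub_worker_loop`` from `python/simpler/worker.py`; `wait_for_dispatch`, `read_callable_id`, and `mark_task_done` are hypothetical stand-ins for the mailbox state handling, which the real loop does inline):

```python
# Hypothetical condensation of the forked SUB child's loop; the real
# _sub_worker_loop polls _OFF_STATE and reports failures via _OFF_ERROR.
while True:
    wait_for_dispatch(buf)               # stand-in: block until the parent marks the mailbox ready
    cid = read_callable_id(buf)          # stand-in: uint64 callable_id from the mailbox
    args = _read_args_from_mailbox(buf)  # real helper: decode the args blob into TaskArgs
    registry[cid](args)                  # the fn(args) contract
    mark_task_done(buf)                  # stand-in: write TASK_DONE for the parent to observe
```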

docs/roadmap.md

Lines changed: 15 additions & 8 deletions
@@ -77,6 +77,21 @@ get if I pip install `main` today", this page.
 `DIST_MAX_SCOPE_DEPTH = 64`; scopes deeper than the ring depth share
 the innermost ring.

+### Uniform `Worker.run(callable, args, config)`
+
+- **`Task` dataclass deleted**: `Worker.run` now takes
+  `(callable, args=None, config=None)` directly. For L3+, `callable` is
+  the orch function; for L2, it is a `ChipCallable`. `config` defaults
+  to `ChipCallConfig()` if omitted.
+- **Orch fn signature is 3-param**: `def orch(o, args, cfg)` — receives
+  the `Orchestrator`, `TaskArgs`, and `ChipCallConfig` passed to
+  `Worker.run`.
+- **Sub callable signature is `fn(args)`** — registered callables now
+  receive the `TaskArgs` decoded from the mailbox blob. The Python child
+  loop (`_sub_worker_loop`) reads the blob at `_OFF_ARGS` and constructs
+  a `TaskArgs` via `_read_args_from_mailbox`. Callable registry stays
+  Python-only (`dict[int, Callable]`).
+
 ### Dispatch internals

 - `IWorker::run(uint64_t callable, TaskArgsView args, ChipCallConfig cfg)`
@@ -119,14 +134,6 @@ get if I pip install `main` today", this page.
 `Mode = THREAD | PROCESS` (no separate fork-proxy classes). Strict-4
 per-worker-type ready queues already landed in PR-D-1.

-### PR-E: uniform `Worker.run` + callable registry unification
-
-- Python `Worker.run` drops the `if level==2` branch.
-- Callable registry moves fully into C++
-  (`unordered_map<uint64_t, nb::object>` owned by `Worker`) so
-  `ChipCallable` and Python `sub` callables share one lookup path.
-  This unblocks L4+ recursion.
-
 ### PR-F: C++ `Worker::run(Task)` for L4+ recursion

 - C++ `Task { OrchFn orch; TaskArgs task_args; CallConfig config; }`
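
For the "Uniform `Worker.run`" entry above, the new contract in code (an illustrative sketch; `chip_callable`, `cid`, `sub_args`, and `my_args` are placeholders):

```python
def my_orch(o, args, cfg):       # 3-param orch: Orchestrator, TaskArgs, ChipCallConfig
    o.submit_next_level(chip_callable, args, cfg)
    o.submit_sub(cid, sub_args)  # the registered callable runs as fn(args)

w.run(my_orch, my_args)          # config omitted: defaults to ChipCallConfig()
```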

docs/task-flow.md

Lines changed: 22 additions & 9 deletions
@@ -244,13 +244,20 @@ void ChipWorker::run(Callable cb, TaskArgsView view, const CallConfig &config) override {

 ### `SubWorker` (Python callable leaf)

-```cpp
-void SubWorker::run(Callable cb, TaskArgsView view, const CallConfig &config) override {
-    uint64_t cid = cb;        // no cast
-    py_registry_[cid](view);  // invoke Python callable with view
-}
+SubWorker execution is handled entirely in Python. The forked child process
+runs ``_sub_worker_loop`` which reads the args blob from the shared-memory
+mailbox, decodes it into a ``TaskArgs`` object, and passes it to the
+registered callable:
+
+```python
+fn(args)  # args: TaskArgs decoded from the mailbox blob
 ```

+The callable receives the same `TaskArgs` that was submitted via
+`orch.submit_sub(cid, args)`, with tags stripped (tags are consumed by the
+Orchestrator at submit time). There is no C++ `SubWorker` class — the
+Python child loop and callable registry are the entire implementation.
+
 Child inherits the Python registry through fork COW; the registry lookup works
 with no IPC.
@@ -333,6 +340,12 @@ slot.config ─┼─► memcpy into shm mailbox ─► child reads view ─
 slot.task_args ─┘   (write_blob)              (read_blob)
 ```

+For SUB workers in PROCESS mode, the child is a Python process running
+``_sub_worker_loop``. The mailbox carries the same blob format, but the
+Python child decodes it via ``_read_args_from_mailbox`` into a ``TaskArgs``
+object and calls ``fn(args)`` directly — the dispatch path bypasses
+``IWorker`` entirely.
+
 The mailbox layout, fork ordering, and child loop are in
 [worker-manager.md](worker-manager.md) §4.

@@ -418,7 +431,7 @@ def my_l3_orch(orch3, args, cfg):
 def my_l4_orch(orch4, args, cfg):
     orch4.submit_next_level(my_l3_orch_handle, args, cfg)

-w4.run(Task(orch=my_l4_orch, task_args=..., config=...))
+w4.run(my_l4_orch, my_args, my_config)
 ```

 L4's WorkerThread dispatches to `w3` via `IWorker::run`. `Worker::run`
@@ -463,14 +476,14 @@ w3 = Worker(level=3, child_mode=PROCESS)
 w3.add_worker(NEXT_LEVEL, chip_worker_0)
 w3.init()  # fork chip_0 here

-w3.run(Task(orch=my_orch, task_args=args, config=CallConfig(block_dim=3)))
+w3.run(my_orch, args, CallConfig(block_dim=3))
 ```

 Step-by-step (PROCESS mode, one chip worker):

 | Step | Where | What happens |
 | ---- | ----- | ------------ |
-| 1 | parent Python | user builds `args: TaskArgs`, calls `w3.run(Task)` |
+| 1 | parent Python | user builds `args: TaskArgs`, calls `w3.run(my_orch, args, config)` |
 | 2 | `Worker::run` | `scope_begin` → call `my_orch(&orch_, args.view(), cfg)` |
 | 3 | `Orchestrator::submit_next_level` | `slot = ring.alloc()`; move `chip_args` into `slot.task_args`; walk tags → `tensormap.lookup(a.data)`, `tensormap.lookup(b.data)`, `tensormap.insert(c.data, slot)`; push ready |
 | 4 | Scheduler thread | pop `slot`; `wt = manager.pick_idle(NEXT_LEVEL)` (WT_chip_0); `wt->dispatch(slot)` |
@@ -481,7 +494,7 @@ Step-by-step (PROCESS mode, one chip worker):
 | 9 | chip_0 child | `run` returns; write `TASK_DONE` |
 | 10 | WT_chip_0 parent | see `TASK_DONE`; call `on_complete_(slot)` |
 | 11 | Scheduler | mark slot COMPLETED; fanout release (none in this DAG); scope_end will release scope ref |
-| 12 | `Worker::run` returns | user's `w3.run(Task)` returns; `c` contains result in shm, visible to user |
+| 12 | `Worker::run` returns | user's `w3.run(...)` returns; `c` contains result in shm, visible to user |

 ---

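Putting the `SubWorker` pieces above together, the submit-to-callable round trip looks like this (a sketch using only the API shown in these docs; `output_tensor` and `w` are placeholders):

```python
def on_chip_done(args):
    # Runs in the forked Python child. `args` is the TaskArgs rebuilt from
    # the mailbox blob: the same tensors submitted below, with the INPUT
    # tag already stripped at submit time.
    ...

cid = w.register(on_chip_done)
w.init()

def my_orch(orch, args, cfg):
    sub_args = TaskArgs()
    sub_args.add_tensor(make_tensor_arg(output_tensor), TensorArgType.INPUT)
    orch.submit_sub(cid, sub_args)  # child later invokes on_chip_done(args)
```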

docs/worker-manager.md

Lines changed: 8 additions & 3 deletions
@@ -156,15 +156,20 @@ void WorkerThread::dispatch_thread(WorkerDispatch d) {
   `IWorker` casts back appropriately.
 - `IWorker::run` dispatches polymorphically based on the actual worker type.

+Note: SUB workers in PROCESS mode bypass `IWorker` entirely — the Python
+child loop (``_sub_worker_loop``) reads the args blob from the mailbox,
+decodes it into a ``TaskArgs``, and calls the registered callable as
+``fn(args)``. The C++ dispatch path writes the same mailbox format for
+both worker types.
+
 **When is THREAD mode safe?**

 - The IWorker implementation must be thread-safe relative to other concurrent
   calls and other system state
 - `ChipWorker` (dlsym'd runtime.so) is safe when the runtime `.so` and its
   device driver support concurrent use
-- `SubWorker` in THREAD mode is constrained by Python's GIL (all SubWorkers
-  in the pool effectively serialize), but this is often fine for light
-  Python callables
+- SUB workers run in Python child processes (PROCESS mode) where the
+  callable receives ``TaskArgs`` as its sole argument

 ---


python/simpler/orchestrator.py

Lines changed: 3 additions & 3 deletions
@@ -12,18 +12,18 @@
 Orchestrator handle at init, retrieves the C++ object via ``DistWorker.get_orchestrator()``,
 and passes the handle to the user's orch function::

-    def my_orch(orch, args):
+    def my_orch(orch, args, cfg):
         # build the args object yourself; tags drive dependency inference
         a = TaskArgs()
         a.add_tensor(make_tensor_arg(input_tensor), TensorArgType.INPUT)
         a.add_tensor(make_tensor_arg(output_tensor), TensorArgType.OUTPUT)
-        orch.submit_next_level(chip_callable, a, config)
+        orch.submit_next_level(chip_callable, a, cfg)

         sub_args = TaskArgs()
         sub_args.add_tensor(make_tensor_arg(output_tensor), TensorArgType.INPUT)
         orch.submit_sub(cid, sub_args)

-    w.run(Task(orch=my_orch, args=my_args))
+    w.run(my_orch, my_args, my_config)

 Scope/drain lifecycle is managed by ``Worker.run()``; users never call those
 directly.

python/simpler/worker.py

Lines changed: 51 additions & 35 deletions
@@ -13,64 +13,49 @@
 # L2: one NPU chip
 w = Worker(level=2, device_id=8, platform="a2a3", runtime="tensormap_and_ringbuffer")
 w.init()
-w.run(chip_callable, chip_args, block_dim=24)
+w.run(chip_callable, chip_args, config)
 w.close()

 # L3: multiple chips + SubWorkers, auto-discovery in init()
 w = Worker(level=3, device_ids=[8, 9], num_sub_workers=2,
            platform="a2a3", runtime="tensormap_and_ringbuffer")
-cid = w.register(lambda: postprocess())
+cid = w.register(lambda args: postprocess())
 w.init()

-def my_orch(orch, args):
-    r = orch.submit_next_level(chip_callable, chip_args_ptr, config, outputs=[64])
-    orch.submit_sub(cid, inputs=[r.outputs[0].ptr])
+def my_orch(orch, args, cfg):
+    r = orch.submit_next_level(chip_callable, chip_args_ptr, cfg)
+    orch.submit_sub(cid, sub_args)

-w.run(Task(orch=my_orch, args=my_args))
+w.run(my_orch, my_args, my_config)
 w.close()
 """

 import ctypes
 import os
 import struct
 import sys
-from dataclasses import dataclass, field
 from multiprocessing.shared_memory import SharedMemory
 from typing import Any, Callable, Optional

 from .orchestrator import Orchestrator
 from .task_interface import (
     DIST_MAILBOX_SIZE,
+    ChipCallConfig,
     ChipWorker,
+    ContinuousTensor,
+    DataType,
     DistWorker,
+    TaskArgs,
     _ChipWorker,
 )

-# ---------------------------------------------------------------------------
-# Task
-# ---------------------------------------------------------------------------
-
-
-@dataclass
-class Task:
-    """Execution unit for Worker.run() at any level.
-
-    For L2: call ``Worker.run(chip_callable, chip_args, config)`` directly.
-    For L3+: provide an orch function ``fn(orchestrator, args)`` that builds
-    the DAG via ``orchestrator.submit_*``.
-    """
-
-    orch: Callable
-    args: Any = field(default=None)
-
-
 # ---------------------------------------------------------------------------
 # Unified mailbox layout (must match dist_worker_manager.h MAILBOX_OFF_*)
 # ---------------------------------------------------------------------------
 #
 # One layout for both NEXT_LEVEL (chip) and SUB workers. SUB children
-# read `callable` as a uint64 encoding the callable_id and ignore the
-# config + args_blob region.
+# read `callable` as a uint64 encoding the callable_id and decode the
+# args_blob region to pass TaskArgs to the registered callable.

 _OFF_STATE = 0
 _OFF_ERROR = 4
@@ -93,6 +78,35 @@ def _mailbox_addr(shm: SharedMemory) -> int:
     return ctypes.addressof(ctypes.c_char.from_buffer(buf))


+def _read_args_from_mailbox(buf) -> TaskArgs:
+    """Decode the TaskArgs blob written by C++ write_blob from the mailbox.
+
+    Blob layout at _OFF_ARGS:
+        int32 tensor_count (T), int32 scalar_count (S),
+        ContinuousTensor[T] (40 B each), uint64_t[S] (8 B each).
+    """
+    base = _OFF_ARGS
+    t_count = struct.unpack_from("i", buf, base)[0]
+    s_count = struct.unpack_from("i", buf, base + 4)[0]
+
+    args = TaskArgs()
+    ct_off = base + 8
+    for i in range(t_count):
+        off = ct_off + i * 40
+        data = struct.unpack_from("Q", buf, off)[0]
+        shapes = struct.unpack_from("5I", buf, off + 8)
+        ndims = struct.unpack_from("I", buf, off + 28)[0]
+        dtype_val = struct.unpack_from("B", buf, off + 32)[0]
+        ct = ContinuousTensor.make(data, tuple(shapes[:ndims]), DataType(dtype_val))
+        args.add_tensor(ct)
+
+    sc_off = ct_off + t_count * 40
+    for i in range(s_count):
+        args.add_scalar(struct.unpack_from("Q", buf, sc_off + i * 8)[0])
+
+    return args
+
+
 def _sub_worker_loop(buf, registry: dict) -> None:
     """Runs in forked child process. Reads unified mailbox layout."""
     while True:
@@ -105,7 +119,8 @@ def _sub_worker_loop(buf, registry: dict) -> None:
             error = 1
         else:
             try:
-                fn()
+                args = _read_args_from_mailbox(buf)
+                fn(args)
             except Exception:  # noqa: BLE001
                 error = 2
         struct.pack_into("i", buf, _OFF_ERROR, error)
@@ -351,25 +366,26 @@ def _start_level3(self) -> None:
     # run — uniform entry point
     # ------------------------------------------------------------------

-    def run(self, task_or_callable, args=None, **kwargs) -> None:
-        """Execute one task synchronously.
+    def run(self, callable, args=None, config=None) -> None:
+        """Execute one task (L2) or one DAG (L3+) synchronously.

-        L2: run(chip_callable, chip_args, block_dim=N)
-        L3: run(Task(orch=fn, args=...))
+        callable: ChipCallable (L2) or Python orch fn (L3+)
+        args: TaskArgs (optional)
+        config: ChipCallConfig (optional, default-constructed if None)
         """
         assert self._initialized, "Worker not initialized; call init() first"
+        cfg = config if config is not None else ChipCallConfig()

         if self.level == 2:
             assert self._chip_worker is not None
-            self._chip_worker.run(task_or_callable, args, **kwargs)
+            self._chip_worker.run(callable, args, cfg)
         else:
             self._start_level3()
             assert self._orch is not None
             assert self._dist_worker is not None
-            task = task_or_callable
             self._orch._scope_begin()
             try:
-                task.orch(self._orch, task.args)
+                callable(self._orch, args, cfg)
             finally:
                 # Always release scope refs and drain so ring slots aren't
                 # stranded when the orch fn raises mid-DAG.
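
For reference, a pack-side counterpart to `_read_args_from_mailbox` (a hypothetical Python stand-in for the C++ `write_blob`, handy for unit-testing the decoder; it assumes `DataType(1)` is a valid enum member and that native struct layout matches the 40-byte record described in the docstring):

```python
import struct

def _write_args_blob(buf, base, tensors, scalars):
    """Hypothetical mirror of C++ write_blob: int32 T, int32 S, then
    T 40-byte records (uint64 data, 5x uint32 shape, uint32 ndims,
    uint8 dtype, padded to 40 B), then S uint64 scalars."""
    struct.pack_into("ii", buf, base, len(tensors), len(scalars))
    off = base + 8
    for data, shape, dtype_val in tensors:
        shape5 = tuple(shape) + (0,) * (5 - len(shape))
        struct.pack_into("Q5IIB", buf, off, data, *shape5, len(shape), dtype_val)
        off += 40  # 33 packed bytes + 7 bytes of padding per record
    for s in scalars:
        struct.pack_into("Q", buf, off, s)
        off += 8

# Round trip against the decoder added in this commit:
buf = bytearray(DIST_MAILBOX_SIZE)
_write_args_blob(buf, _OFF_ARGS, [(0x1000, (4, 8), 1)], [7, 42])
args = _read_args_from_mailbox(buf)  # one 4x8 tensor, scalars 7 and 42
```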

simpler_setup/scene_test.py

Lines changed: 4 additions & 5 deletions
@@ -815,8 +815,6 @@ def _run_and_validate_l3(
         enable_profiling=False,
         enable_dump_tensor=False,
     ):
-        from simpler.worker import Task  # noqa: PLC0415
-
         params = case.get("params", {})
         config_dict = case.get("config", {})

@@ -854,12 +852,13 @@ def _run_and_validate_l3(
             enable_dump_tensor=enable_dump_tensor,
         )

-        # Wrap in Task — user orch signature: (orch, callables, task_args, config)
-        def task_orch(orch, _unused, _ns=ns, _test_args=test_args, _config=config):
+        # Orch fn signature: (orch, args, cfg) — inner fn forwards to
+        # the user's scene orch which takes (orch, callables, task_args, config).
+        def task_orch(orch, _args, _cfg, _ns=ns, _test_args=test_args, _config=config):
             orch_fn(orch, _ns, _test_args, _config)

         with _temporary_env(self._resolve_env()):
-            worker.run(Task(orch=task_orch))
+            worker.run(task_orch)

         if not skip_golden:
             _compare_outputs(test_args, golden_args, all_tensor_names, self.RTOL, self.ATOL)

tests/st/a2a3/tensormap_and_ringbuffer/test_l3_dependency.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 KERNELS_BASE = "../../../../examples/a2a3/tensormap_and_ringbuffer/vector_example/kernels"


-def verify():
+def verify(args):
     """SubCallable — dependency target, runs after ChipTask completes."""


tests/st/a2a3/tensormap_and_ringbuffer/test_l3_group.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 KERNELS_BASE = "../../../../examples/a2a3/tensormap_and_ringbuffer/vector_example/kernels"


-def verify():
+def verify(args):
     """SubCallable — runs after group completes."""

