Skip to content

Commit fba9ac6

Browse files
committed
made taskqueue persistence optional
1 parent 26c5d3f commit fba9ac6

4 files changed

Lines changed: 298 additions & 93 deletions

File tree

docs/source/multi_agent_example.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333

3434
from effectful.handlers.llm import Template, Tool
3535
from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler
36-
from effectful.handlers.llm.multi import Choreography, ChoreographyError, scatter
36+
from effectful.handlers.llm.multi import (
37+
Choreography,
38+
ChoreographyError,
39+
PersistentTaskQueue,
40+
scatter,
41+
)
3742
from effectful.handlers.llm.persistence import PersistenceHandler, PersistentAgent
3843
from effectful.ops.types import NotHandled
3944

@@ -272,7 +277,7 @@ def main() -> None:
272277
choreo = Choreography(
273278
build_project,
274279
agents=[architect, coder1, coder2, reviewer1, reviewer2],
275-
state_dir=STATE_DIR,
280+
queue=PersistentTaskQueue(STATE_DIR / "choreo_queue"),
276281
handlers=[
277282
LiteLLMProvider(model=MODEL),
278283
RetryLLMHandler(),

effectful/handlers/llm/multi.py

Lines changed: 177 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def build_codebase(
8888
choreo = Choreography(
8989
build_codebase,
9090
agents=[architect, coder, reviewer],
91-
state_dir=Path("./state"),
91+
queue=PersistentTaskQueue(Path("./state/choreo_queue")),
9292
handlers=[
9393
LiteLLMProvider(model="gpt-4o-mini"),
9494
RetryLLMHandler(),
@@ -105,7 +105,7 @@ def build_codebase(
105105
106106
Example — parallel scatter across multiple coders::
107107
108-
from effectful.handlers.llm.multi import Choreography, scatter
108+
from effectful.handlers.llm.multi import Choreography, PersistentTaskQueue, scatter
109109
110110
def build_parallel(
111111
project_spec: str,
@@ -129,7 +129,7 @@ def build_parallel(
129129
choreo = Choreography(
130130
build_parallel,
131131
agents=[architect, coder1, coder2, coder3, reviewer],
132-
state_dir=Path("./state"),
132+
queue=PersistentTaskQueue(Path("./state/choreo_queue")),
133133
handlers=[LiteLLMProvider(model="gpt-4o-mini"), RetryLLMHandler()],
134134
)
135135
# Pass coder as a list — scatter distributes across all three
@@ -142,6 +142,7 @@ def build_parallel(
142142
143143
"""
144144

145+
import abc
145146
import contextlib
146147
import json
147148
import threading
@@ -167,7 +168,168 @@ class TaskStatus(StrEnum):
167168
FAILED = "failed"
168169

169170

170-
class TaskQueue:
171+
class TaskQueue(abc.ABC):
    """Abstract task queue with claim-based ownership.

    A task is submitted as pending, claimed by exactly one owner, and then
    either completed (done) or failed.  Claims held by an owner can be
    released back to pending with :meth:`release_stale_claims`.

    Subclasses implement persistent (file-based) or in-memory storage.
    All methods are thread-safe.
    """

    @abc.abstractmethod
    def submit(
        self,
        task_type: str,
        payload: dict,
        task_id: str | None = None,
    ) -> str:
        """Add a new pending task and return its task ID.

        Args:
            task_type: Category used by :meth:`claim` to select tasks.
            payload: Task data stored alongside the task.
            task_id: Explicit ID for the task.  When ``None`` the
                implementation picks an ID itself.

        Returns:
            The ID of the task (the given *task_id* when one was passed).

        Idempotent when *task_id* is specified: if a task with that ID
        already exists (in any state), the call is a no-op.
        """

    @abc.abstractmethod
    def claim(self, task_type: str, owner: str) -> dict | None:
        """Atomically claim the next pending task of the given type.

        Args:
            task_type: Only tasks submitted with this type are considered.
            owner: Identifier recorded as the holder of the claim.

        Returns:
            The task dict if one was claimed, or ``None`` when no pending
            task of that type exists.
        """

    @abc.abstractmethod
    def claim_by_prefix(self, prefix: str, owner: str) -> dict | None:
        """Claim any pending task whose ID starts with *prefix*.

        Args:
            prefix: Leading part of the task ID to match.
            owner: Identifier recorded as the holder of the claim.

        Returns:
            The task dict if one was claimed, or ``None``.
        """

    @abc.abstractmethod
    def complete(self, task_id: str, owner: str, result: Any = None) -> None:
        """Mark a claimed task as done with *result*.

        The implementations in this module return silently when the task
        is not currently claimed.
        """

    @abc.abstractmethod
    def fail(self, task_id: str, owner: str, error: str) -> None:
        """Mark a claimed task as failed, recording *error*.

        The implementations in this module return silently when the task
        is not currently claimed.
        """

    @abc.abstractmethod
    def get_result(self, task_id: str) -> Any | None:
        """Return the result of a completed task, or ``None``.

        ``None`` is returned both when the task is not done and when it
        completed with a ``None`` result, so callers cannot distinguish
        the two from the return value alone.
        """

    @abc.abstractmethod
    def release_stale_claims(self, owner: str) -> int:
        """Release tasks claimed by *owner* back to pending.

        Call on startup to reclaim work from a prior crashed session.

        Returns:
            The number of tasks that were released.
        """

    @abc.abstractmethod
    def pending_count(self, task_type: str | None = None) -> int:
        """Count pending tasks, optionally filtered by type.

        Args:
            task_type: When given, only pending tasks of this type are
                counted; ``None`` counts every pending task.
        """

    @abc.abstractmethod
    def all_done(self) -> bool:
        """``True`` if no pending or claimed tasks remain."""
228+
229+
230+
class InMemoryTaskQueue(TaskQueue):
    """In-memory task queue for testing or ephemeral workflows.

    Tasks are kept in a plain dict keyed by task ID, so the queue is not
    crash-tolerant: all state is lost when the process exits.  Every
    public method holds a single lock while it touches that dict, which
    makes the queue safe to share between threads.

    Each task is a dict with the keys ``id``, ``type``, ``payload``,
    ``status``, ``owner`` and ``result``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._tasks: dict[str, dict] = {}  # task_id -> task dict

    def submit(
        self,
        task_type: str,
        payload: dict,
        task_id: str | None = None,
    ) -> str:
        """Add a pending task and return its ID.

        When *task_id* is ``None`` a short random ID (first 8 characters
        of a UUID4) is generated.  If a task with the given ID already
        exists in any state, nothing is changed and the ID is returned.
        """
        if task_id is None:
            task_id = str(uuid.uuid4())[:8]
        with self._lock:
            if task_id not in self._tasks:
                self._tasks[task_id] = {
                    "id": task_id,
                    "type": task_type,
                    "payload": payload,
                    "status": TaskStatus.PENDING,
                    "owner": "",
                    "result": None,
                }
        return task_id

    def _claim_first(self, matches: Callable[[dict], bool], owner: str) -> dict | None:
        """Claim the first pending task, in task-ID order, accepted by *matches*.

        Returns a shallow copy of the claimed task so callers cannot alter
        the queue's bookkeeping fields, or ``None`` when nothing matched.
        """
        with self._lock:
            for task_id in sorted(self._tasks):
                task = self._tasks[task_id]
                if task["status"] == TaskStatus.PENDING and matches(task):
                    task["status"] = TaskStatus.CLAIMED
                    task["owner"] = owner
                    return dict(task)
        return None

    def _claimed_by(self, task_id: str, owner: str) -> dict | None:
        """Return the task only if it is currently claimed by *owner*.

        Must be called with ``self._lock`` held.
        """
        task = self._tasks.get(task_id)
        if task is None or task["status"] != TaskStatus.CLAIMED:
            return None
        # Only the owner holding the claim may finish the task.  This
        # matches PersistentTaskQueue, which looks the claimed task up by
        # owner and does nothing for any other caller.
        if task["owner"] != owner:
            return None
        return task

    def claim(self, task_type: str, owner: str) -> dict | None:
        """Claim the next pending task of *task_type* for *owner*.

        Returns a copy of the task dict, or ``None`` if none is pending.
        """
        return self._claim_first(lambda task: task["type"] == task_type, owner)

    def claim_by_prefix(self, prefix: str, owner: str) -> dict | None:
        """Claim the next pending task whose ID starts with *prefix*.

        Returns a copy of the task dict, or ``None`` if none is pending.
        """
        return self._claim_first(lambda task: task["id"].startswith(prefix), owner)

    def complete(self, task_id: str, owner: str, result: Any = None) -> None:
        """Mark the task as done with *result*.

        Does nothing unless the task exists and is claimed by *owner*.
        """
        with self._lock:
            task = self._claimed_by(task_id, owner)
            if task is None:
                return
            task["status"] = TaskStatus.DONE
            task["result"] = result

    def fail(self, task_id: str, owner: str, error: str) -> None:
        """Mark the task as failed, storing ``{"error": error}`` as its result.

        Does nothing unless the task exists and is claimed by *owner*.
        """
        with self._lock:
            task = self._claimed_by(task_id, owner)
            if task is None:
                return
            task["status"] = TaskStatus.FAILED
            task["result"] = {"error": error}

    def get_result(self, task_id: str) -> Any | None:
        """Return the stored result of a done task, or ``None`` otherwise.

        ``None`` is also returned for a task that completed with a
        ``None`` result.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is not None and task["status"] == TaskStatus.DONE:
                return task["result"]
        return None

    def release_stale_claims(self, owner: str) -> int:
        """Return every task claimed by *owner* to pending.

        Returns the number of tasks released.
        """
        count = 0
        with self._lock:
            for task in self._tasks.values():
                if task["status"] == TaskStatus.CLAIMED and task["owner"] == owner:
                    task["status"] = TaskStatus.PENDING
                    task["owner"] = ""
                    count += 1
        return count

    def pending_count(self, task_type: str | None = None) -> int:
        """Count pending tasks, restricted to *task_type* when it is given."""
        with self._lock:
            return sum(
                1
                for task in self._tasks.values()
                if task["status"] == TaskStatus.PENDING
                and (task_type is None or task["type"] == task_type)
            )

    def all_done(self) -> bool:
        """``True`` if no task is pending or claimed (also when empty)."""
        with self._lock:
            return not any(
                task["status"] in (TaskStatus.PENDING, TaskStatus.CLAIMED)
                for task in self._tasks.values()
            )
330+
331+
332+
class PersistentTaskQueue(TaskQueue):
171333
"""File-based task queue with claim-based ownership.
172334
173335
Each task is a JSON file in *queue_dir*. Claiming a task
@@ -196,11 +358,6 @@ def submit(
196358
payload: dict,
197359
task_id: str | None = None,
198360
) -> str:
199-
"""Add a new task. Returns the task ID.
200-
201-
Idempotent when *task_id* is specified: if a task with that ID
202-
already exists (in any state), the call is a no-op.
203-
"""
204361
if task_id is None:
205362
task_id = str(uuid.uuid4())[:8]
206363
with self._lock:
@@ -219,10 +376,6 @@ def submit(
219376
return task_id
220377

221378
def claim(self, task_type: str, owner: str) -> dict | None:
222-
"""Atomically claim the next pending task of the given type.
223-
224-
Returns the task dict if one was claimed, or ``None``.
225-
"""
226379
with self._lock:
227380
for path in sorted(self.queue_dir.glob(f"*.{TaskStatus.PENDING}.json")):
228381
task = json.loads(path.read_text())
@@ -240,7 +393,6 @@ def claim(self, task_type: str, owner: str) -> dict | None:
240393
return None
241394

242395
def claim_by_prefix(self, prefix: str, owner: str) -> dict | None:
243-
"""Claim any pending task whose ID starts with *prefix*."""
244396
with self._lock:
245397
for path in sorted(self.queue_dir.glob(f"*.{TaskStatus.PENDING}.json")):
246398
fname = path.name.split(".")[0]
@@ -259,7 +411,6 @@ def claim_by_prefix(self, prefix: str, owner: str) -> dict | None:
259411
return None
260412

261413
def complete(self, task_id: str, owner: str, result: Any = None) -> None:
262-
"""Mark a claimed task as done with *result*."""
263414
claimed = self._task_path(task_id, TaskStatus.CLAIMED, owner)
264415
if not claimed.exists():
265416
return
@@ -276,7 +427,6 @@ def complete(self, task_id: str, owner: str, result: Any = None) -> None:
276427
pass
277428

278429
def fail(self, task_id: str, owner: str, error: str) -> None:
279-
"""Mark a claimed task as failed."""
280430
claimed = self._task_path(task_id, TaskStatus.CLAIMED, owner)
281431
if not claimed.exists():
282432
return
@@ -288,18 +438,13 @@ def fail(self, task_id: str, owner: str, error: str) -> None:
288438
failed.write_text(json.dumps(task, indent=2, default=str))
289439

290440
def get_result(self, task_id: str) -> Any | None:
291-
"""Return the result of a completed task, or ``None``."""
292441
done = self._task_path(task_id, TaskStatus.DONE)
293442
if done.exists():
294443
task = json.loads(done.read_text())
295444
return task.get("result")
296445
return None
297446

298447
def release_stale_claims(self, owner: str) -> int:
299-
"""Release tasks claimed by *owner* back to pending.
300-
301-
Call on startup to reclaim work from a prior crashed session.
302-
"""
303448
count = 0
304449
with self._lock:
305450
for path in self.queue_dir.glob(f"*.{TaskStatus.CLAIMED}.{owner}.json"):
@@ -313,7 +458,6 @@ def release_stale_claims(self, owner: str) -> int:
313458
return count
314459

315460
def pending_count(self, task_type: str | None = None) -> int:
316-
"""Count pending tasks, optionally filtered by type."""
317461
count = 0
318462
for path in self.queue_dir.glob(f"*.{TaskStatus.PENDING}.json"):
319463
if task_type is None:
@@ -325,7 +469,6 @@ def pending_count(self, task_type: str | None = None) -> int:
325469
return count
326470

327471
def all_done(self) -> bool:
328-
"""``True`` if no pending or claimed tasks remain."""
329472
for status in (TaskStatus.PENDING, TaskStatus.CLAIMED):
330473
if list(self.queue_dir.glob(f"*.{status}*")):
331474
return False
@@ -479,7 +622,10 @@ def _scatter(self, items: list, agent: Agent, fn: Callable) -> list:
479622
agents = agent if isinstance(agent, list) else [agent]
480623
scatter_ids = {a.__agent_id__ for a in agents}
481624

482-
# Submit one task per item
625+
# Submit one task per item. All agent threads execute this
626+
# loop, but submit() is idempotent on task_id — the
627+
# deterministic ID (step_id:index) ensures each task is
628+
# created exactly once regardless of how many threads call it.
483629
for i in range(len(items)):
484630
self._queue.submit(
485631
task_type=f"scatter-{step_id}",
@@ -512,7 +658,7 @@ def _scatter(self, items: list, agent: Agent, fn: Callable) -> list:
512658

513659

514660
class Choreography:
515-
"""Run a choreographic program with crash-tolerant endpoint projection.
661+
"""Run a choreographic program with endpoint projection.
516662
517663
Each agent gets its own thread. Template calls are routed via
518664
the :class:`TaskQueue`: the owning agent claims and executes,
@@ -524,7 +670,9 @@ class Choreography:
524670
this same function; EPP makes each thread behave
525671
differently.
526672
agents: The agents participating in the choreography.
527-
state_dir: Directory for the persistent task queue.
673+
queue: The task queue to use. Defaults to
674+
:class:`InMemoryTaskQueue` if not provided. Pass a
675+
:class:`PersistentTaskQueue` for crash tolerance.
528676
handlers: Handler instances to install per-thread beneath
529677
the EPP handler (e.g. LLM provider, retry handler,
530678
persistence handler).
@@ -536,7 +684,7 @@ class Choreography:
536684
choreo = Choreography(
537685
build_codebase,
538686
agents=[architect, coder, reviewer],
539-
state_dir=Path("./state"),
687+
queue=PersistentTaskQueue(Path("./state/choreo_queue")),
540688
handlers=[
541689
LiteLLMProvider(model="gpt-4o-mini"),
542690
RetryLLMHandler(),
@@ -555,16 +703,15 @@ def __init__(
555703
self,
556704
program: Callable[..., Any],
557705
agents: Sequence[Agent],
558-
state_dir: Path,
706+
queue: TaskQueue | None = None,
559707
handlers: Sequence[Interpretation | ObjectInterpretation] | None = None,
560708
poll_interval: float = 0.1,
561709
) -> None:
562710
self.program = program
563711
self.agents = list(agents)
564-
self.state_dir = Path(state_dir)
565712
self.handlers = list(handlers or [])
566713
self.poll_interval = poll_interval
567-
self._queue = TaskQueue(self.state_dir / "choreo_queue")
714+
self._queue = queue if queue is not None else InMemoryTaskQueue()
568715

569716
@property
570717
def queue(self) -> TaskQueue:

tests/test_handlers_llm_provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2507,11 +2507,13 @@ def build_project(
25072507
coder1 = Coder(agent_id="coder-1")
25082508
reviewer = Reviewer(agent_id="reviewer")
25092509

2510+
from effectful.handlers.llm.multi import PersistentTaskQueue
2511+
25102512
state_dir = tmp_path / "state"
25112513
choreo = Choreography(
25122514
build_project,
25132515
agents=[architect, coder1, reviewer],
2514-
state_dir=state_dir,
2516+
queue=PersistentTaskQueue(state_dir / "choreo_queue"),
25152517
handlers=[
25162518
LiteLLMProvider(model="gpt-4o-mini", max_tokens=300),
25172519
LimitLLMCallsHandler(max_calls=15),

0 commit comments

Comments
 (0)