@@ -88,7 +88,7 @@ def build_codebase(
8888 choreo = Choreography(
8989 build_codebase,
9090 agents=[architect, coder, reviewer],
91- state_dir= Path("./state" ),
91+ queue=PersistentTaskQueue( Path("./state/choreo_queue") ),
9292 handlers=[
9393 LiteLLMProvider(model="gpt-4o-mini"),
9494 RetryLLMHandler(),
@@ -105,7 +105,7 @@ def build_codebase(
105105
106106Example — parallel scatter across multiple coders::
107107
108- from effectful.handlers.llm.multi import Choreography, scatter
108+ from effectful.handlers.llm.multi import Choreography, PersistentTaskQueue, scatter
109109
110110 def build_parallel(
111111 project_spec: str,
@@ -129,7 +129,7 @@ def build_parallel(
129129 choreo = Choreography(
130130 build_parallel,
131131 agents=[architect, coder1, coder2, coder3, reviewer],
132- state_dir= Path("./state" ),
132+ queue=PersistentTaskQueue( Path("./state/choreo_queue") ),
133133 handlers=[LiteLLMProvider(model="gpt-4o-mini"), RetryLLMHandler()],
134134 )
135135 # Pass coder as a list — scatter distributes across all three
@@ -142,6 +142,7 @@ def build_parallel(
142142
143143"""
144144
145+ import abc
145146import contextlib
146147import json
147148import threading
@@ -167,7 +168,168 @@ class TaskStatus(StrEnum):
167168 FAILED = "failed"
168169
169170
170- class TaskQueue :
171+ class TaskQueue (abc .ABC ):
172+ """Abstract task queue with claim-based ownership.
173+
174+ Subclasses implement persistent (file-based) or in-memory storage.
175+ All methods are thread-safe.
176+ """
177+
178+ @abc .abstractmethod
179+ def submit (
180+ self ,
181+ task_type : str ,
182+ payload : dict ,
183+ task_id : str | None = None ,
184+ ) -> str :
185+ """Add a new task. Returns the task ID.
186+
187+ Idempotent when *task_id* is specified: if a task with that ID
188+ already exists (in any state), the call is a no-op.
189+ """
190+
191+ @abc .abstractmethod
192+ def claim (self , task_type : str , owner : str ) -> dict | None :
193+ """Atomically claim the next pending task of the given type.
194+
195+ Returns the task dict if one was claimed, or ``None``.
196+ """
197+
198+ @abc .abstractmethod
199+ def claim_by_prefix (self , prefix : str , owner : str ) -> dict | None :
200+ """Claim any pending task whose ID starts with *prefix*."""
201+
202+ @abc .abstractmethod
203+ def complete (self , task_id : str , owner : str , result : Any = None ) -> None :
204+ """Mark a claimed task as done with *result*."""
205+
206+ @abc .abstractmethod
207+ def fail (self , task_id : str , owner : str , error : str ) -> None :
208+ """Mark a claimed task as failed."""
209+
210+ @abc .abstractmethod
211+ def get_result (self , task_id : str ) -> Any | None :
212+ """Return the result of a completed task, or ``None``."""
213+
214+ @abc .abstractmethod
215+ def release_stale_claims (self , owner : str ) -> int :
216+ """Release tasks claimed by *owner* back to pending.
217+
218+ Call on startup to reclaim work from a prior crashed session.
219+ """
220+
221+ @abc .abstractmethod
222+ def pending_count (self , task_type : str | None = None ) -> int :
223+ """Count pending tasks, optionally filtered by type."""
224+
225+ @abc .abstractmethod
226+ def all_done (self ) -> bool :
227+ """``True`` if no pending or claimed tasks remain."""
228+
229+
230+ class InMemoryTaskQueue (TaskQueue ):
231+ """In-memory task queue for testing or ephemeral workflows.
232+
233+ Not crash-tolerant — all state is lost when the process exits.
234+ Thread-safe via a single lock.
235+ """
236+
237+ def __init__ (self ) -> None :
238+ self ._lock = threading .Lock ()
239+ self ._tasks : dict [str , dict ] = {} # task_id -> task dict
240+
241+ def submit (
242+ self ,
243+ task_type : str ,
244+ payload : dict ,
245+ task_id : str | None = None ,
246+ ) -> str :
247+ if task_id is None :
248+ task_id = str (uuid .uuid4 ())[:8 ]
249+ with self ._lock :
250+ if task_id in self ._tasks :
251+ return task_id
252+ self ._tasks [task_id ] = {
253+ "id" : task_id ,
254+ "type" : task_type ,
255+ "payload" : payload ,
256+ "status" : TaskStatus .PENDING ,
257+ "owner" : "" ,
258+ "result" : None ,
259+ }
260+ return task_id
261+
262+ def claim (self , task_type : str , owner : str ) -> dict | None :
263+ with self ._lock :
264+ for task_id in sorted (self ._tasks ):
265+ task = self ._tasks [task_id ]
266+ if task ["status" ] == TaskStatus .PENDING and task ["type" ] == task_type :
267+ task ["status" ] = TaskStatus .CLAIMED
268+ task ["owner" ] = owner
269+ return dict (task )
270+ return None
271+
272+ def claim_by_prefix (self , prefix : str , owner : str ) -> dict | None :
273+ with self ._lock :
274+ for task_id in sorted (self ._tasks ):
275+ task = self ._tasks [task_id ]
276+ if task ["status" ] == TaskStatus .PENDING and task_id .startswith (prefix ):
277+ task ["status" ] = TaskStatus .CLAIMED
278+ task ["owner" ] = owner
279+ return dict (task )
280+ return None
281+
282+ def complete (self , task_id : str , owner : str , result : Any = None ) -> None :
283+ with self ._lock :
284+ task = self ._tasks .get (task_id )
285+ if task is None or task ["status" ] != TaskStatus .CLAIMED :
286+ return
287+ task ["status" ] = TaskStatus .DONE
288+ task ["result" ] = result
289+
290+ def fail (self , task_id : str , owner : str , error : str ) -> None :
291+ with self ._lock :
292+ task = self ._tasks .get (task_id )
293+ if task is None or task ["status" ] != TaskStatus .CLAIMED :
294+ return
295+ task ["status" ] = TaskStatus .FAILED
296+ task ["result" ] = {"error" : error }
297+
298+ def get_result (self , task_id : str ) -> Any | None :
299+ with self ._lock :
300+ task = self ._tasks .get (task_id )
301+ if task is not None and task ["status" ] == TaskStatus .DONE :
302+ return task ["result" ]
303+ return None
304+
305+ def release_stale_claims (self , owner : str ) -> int :
306+ count = 0
307+ with self ._lock :
308+ for task in self ._tasks .values ():
309+ if task ["status" ] == TaskStatus .CLAIMED and task ["owner" ] == owner :
310+ task ["status" ] = TaskStatus .PENDING
311+ task ["owner" ] = ""
312+ count += 1
313+ return count
314+
315+ def pending_count (self , task_type : str | None = None ) -> int :
316+ with self ._lock :
317+ return sum (
318+ 1
319+ for t in self ._tasks .values ()
320+ if t ["status" ] == TaskStatus .PENDING
321+ and (task_type is None or t ["type" ] == task_type )
322+ )
323+
324+ def all_done (self ) -> bool :
325+ with self ._lock :
326+ return not any (
327+ t ["status" ] in (TaskStatus .PENDING , TaskStatus .CLAIMED )
328+ for t in self ._tasks .values ()
329+ )
330+
331+
332+ class PersistentTaskQueue (TaskQueue ):
171333 """File-based task queue with claim-based ownership.
172334
173335 Each task is a JSON file in *queue_dir*. Claiming a task
@@ -196,11 +358,6 @@ def submit(
196358 payload : dict ,
197359 task_id : str | None = None ,
198360 ) -> str :
199- """Add a new task. Returns the task ID.
200-
201- Idempotent when *task_id* is specified: if a task with that ID
202- already exists (in any state), the call is a no-op.
203- """
204361 if task_id is None :
205362 task_id = str (uuid .uuid4 ())[:8 ]
206363 with self ._lock :
@@ -219,10 +376,6 @@ def submit(
219376 return task_id
220377
221378 def claim (self , task_type : str , owner : str ) -> dict | None :
222- """Atomically claim the next pending task of the given type.
223-
224- Returns the task dict if one was claimed, or ``None``.
225- """
226379 with self ._lock :
227380 for path in sorted (self .queue_dir .glob (f"*.{ TaskStatus .PENDING } .json" )):
228381 task = json .loads (path .read_text ())
@@ -240,7 +393,6 @@ def claim(self, task_type: str, owner: str) -> dict | None:
240393 return None
241394
242395 def claim_by_prefix (self , prefix : str , owner : str ) -> dict | None :
243- """Claim any pending task whose ID starts with *prefix*."""
244396 with self ._lock :
245397 for path in sorted (self .queue_dir .glob (f"*.{ TaskStatus .PENDING } .json" )):
246398 fname = path .name .split ("." )[0 ]
@@ -259,7 +411,6 @@ def claim_by_prefix(self, prefix: str, owner: str) -> dict | None:
259411 return None
260412
261413 def complete (self , task_id : str , owner : str , result : Any = None ) -> None :
262- """Mark a claimed task as done with *result*."""
263414 claimed = self ._task_path (task_id , TaskStatus .CLAIMED , owner )
264415 if not claimed .exists ():
265416 return
@@ -276,7 +427,6 @@ def complete(self, task_id: str, owner: str, result: Any = None) -> None:
276427 pass
277428
278429 def fail (self , task_id : str , owner : str , error : str ) -> None :
279- """Mark a claimed task as failed."""
280430 claimed = self ._task_path (task_id , TaskStatus .CLAIMED , owner )
281431 if not claimed .exists ():
282432 return
@@ -288,18 +438,13 @@ def fail(self, task_id: str, owner: str, error: str) -> None:
288438 failed .write_text (json .dumps (task , indent = 2 , default = str ))
289439
290440 def get_result (self , task_id : str ) -> Any | None :
291- """Return the result of a completed task, or ``None``."""
292441 done = self ._task_path (task_id , TaskStatus .DONE )
293442 if done .exists ():
294443 task = json .loads (done .read_text ())
295444 return task .get ("result" )
296445 return None
297446
298447 def release_stale_claims (self , owner : str ) -> int :
299- """Release tasks claimed by *owner* back to pending.
300-
301- Call on startup to reclaim work from a prior crashed session.
302- """
303448 count = 0
304449 with self ._lock :
305450 for path in self .queue_dir .glob (f"*.{ TaskStatus .CLAIMED } .{ owner } .json" ):
@@ -313,7 +458,6 @@ def release_stale_claims(self, owner: str) -> int:
313458 return count
314459
315460 def pending_count (self , task_type : str | None = None ) -> int :
316- """Count pending tasks, optionally filtered by type."""
317461 count = 0
318462 for path in self .queue_dir .glob (f"*.{ TaskStatus .PENDING } .json" ):
319463 if task_type is None :
@@ -325,7 +469,6 @@ def pending_count(self, task_type: str | None = None) -> int:
325469 return count
326470
327471 def all_done (self ) -> bool :
328- """``True`` if no pending or claimed tasks remain."""
329472 for status in (TaskStatus .PENDING , TaskStatus .CLAIMED ):
330473 if list (self .queue_dir .glob (f"*.{ status } *" )):
331474 return False
@@ -479,7 +622,10 @@ def _scatter(self, items: list, agent: Agent, fn: Callable) -> list:
479622 agents = agent if isinstance (agent , list ) else [agent ]
480623 scatter_ids = {a .__agent_id__ for a in agents }
481624
482- # Submit one task per item
625+ # Submit one task per item. All agent threads execute this
626+ # loop, but submit() is idempotent on task_id — the
627+ # deterministic ID (step_id:index) ensures each task is
628+ # created exactly once regardless of how many threads call it.
483629 for i in range (len (items )):
484630 self ._queue .submit (
485631 task_type = f"scatter-{ step_id } " ,
@@ -512,7 +658,7 @@ def _scatter(self, items: list, agent: Agent, fn: Callable) -> list:
512658
513659
514660class Choreography :
515- """Run a choreographic program with crash-tolerant endpoint projection.
661+ """Run a choreographic program with endpoint projection.
516662
517663 Each agent gets its own thread. Template calls are routed via
518664 the :class:`TaskQueue`: the owning agent claims and executes,
@@ -524,7 +670,9 @@ class Choreography:
524670 this same function; EPP makes each thread behave
525671 differently.
526672 agents: The agents participating in the choreography.
527- state_dir: Directory for the persistent task queue.
673+ queue: The task queue to use. Defaults to
674+ :class:`InMemoryTaskQueue` if not provided. Pass a
675+ :class:`PersistentTaskQueue` for crash tolerance.
528676 handlers: Handler instances to install per-thread beneath
529677 the EPP handler (e.g. LLM provider, retry handler,
530678 persistence handler).
@@ -536,7 +684,7 @@ class Choreography:
536684 choreo = Choreography(
537685 build_codebase,
538686 agents=[architect, coder, reviewer],
539- state_dir= Path("./state" ),
687+ queue=PersistentTaskQueue( Path("./state/choreo_queue") ),
540688 handlers=[
541689 LiteLLMProvider(model="gpt-4o-mini"),
542690 RetryLLMHandler(),
@@ -555,16 +703,15 @@ def __init__(
555703 self ,
556704 program : Callable [..., Any ],
557705 agents : Sequence [Agent ],
558- state_dir : Path ,
706+ queue : TaskQueue | None = None ,
559707 handlers : Sequence [Interpretation | ObjectInterpretation ] | None = None ,
560708 poll_interval : float = 0.1 ,
561709 ) -> None :
562710 self .program = program
563711 self .agents = list (agents )
564- self .state_dir = Path (state_dir )
565712 self .handlers = list (handlers or [])
566713 self .poll_interval = poll_interval
567- self ._queue = TaskQueue ( self . state_dir / "choreo_queue" )
714+ self ._queue = queue if queue is not None else InMemoryTaskQueue ( )
568715
569716 @property
570717 def queue (self ) -> TaskQueue :
0 commit comments