File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -633,7 +633,7 @@ async def stream_async(
633633 # Using threading.Lock instead of asyncio.Lock because run_async() creates
634634 # separate event loops in different threads
635635 lock_acquired = False
636- if self ._concurrent_invocation_mode == "throw" :
636+ if self ._concurrent_invocation_mode == ConcurrentInvocationMode . THROW :
637637 lock_acquired = self ._invocation_lock .acquire (blocking = False )
638638 if not lock_acquired :
639639 raise ConcurrencyException (
Original file line number Diff line number Diff line change 33This module defines the types used for an Agent.
44"""
55
6- from typing import Literal , TypeAlias
6+ from enum import Enum
7+ from typing import TypeAlias
78
89from .content import ContentBlock , Messages
910from .interrupt import InterruptResponseContent
1011
1112AgentInput : TypeAlias = str | list [ContentBlock ] | list [InterruptResponseContent ] | Messages | None
1213
13- ConcurrentInvocationMode = Literal ["throw" , "unsafe_reentrant" ]
14- """Mode controlling concurrent invocation behavior.
1514
16- Values:
17- throw: Raises ConcurrencyException if concurrent invocation is attempted (default).
18- unsafe_reentrant: Allows concurrent invocations without locking.
15+ class ConcurrentInvocationMode (str , Enum ):
16+ """Mode controlling concurrent invocation behavior.
1917
20- Warning:
21- The ``unsafe_reentrant`` mode makes no guarantees about resulting behavior and is
22- provided only for advanced use cases where the caller understands the risks.
23- """
18+ Values:
19+ THROW: Raises ConcurrencyException if concurrent invocation is attempted (default).
20+ UNSAFE_REENTRANT: Allows concurrent invocations without locking.
21+
22+ Warning:
23+ The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is
24+ provided only for advanced use cases where the caller understands the risks.
25+ """
26+
27+ THROW = "throw"
28+ UNSAFE_REENTRANT = "unsafe_reentrant"
Original file line number Diff line number Diff line change @@ -2382,6 +2382,22 @@ def test_agent_concurrent_invocation_mode_stores_value():
23822382 assert agent_reentrant ._concurrent_invocation_mode == "unsafe_reentrant"
23832383
23842384
2385+ def test_agent_concurrent_invocation_mode_accepts_enum ():
2386+ """Test that concurrent_invocation_mode accepts enum values as well as strings."""
2387+ from strands .types .agent import ConcurrentInvocationMode
2388+
2389+ model = MockedModelProvider ([{"role" : "assistant" , "content" : [{"text" : "hello" }]}])
2390+
2391+ # Using enum values
2392+ agent_throw = Agent (model = model , concurrent_invocation_mode = ConcurrentInvocationMode .THROW )
2393+ assert agent_throw ._concurrent_invocation_mode == "throw"
2394+ assert agent_throw ._concurrent_invocation_mode == ConcurrentInvocationMode .THROW
2395+
2396+ agent_reentrant = Agent (model = model , concurrent_invocation_mode = ConcurrentInvocationMode .UNSAFE_REENTRANT )
2397+ assert agent_reentrant ._concurrent_invocation_mode == "unsafe_reentrant"
2398+ assert agent_reentrant ._concurrent_invocation_mode == ConcurrentInvocationMode .UNSAFE_REENTRANT
2399+
2400+
23852401@pytest .mark .asyncio
23862402async def test_agent_sequential_invocations_work ():
23872403 """Test that sequential invocations work correctly after lock is released."""
You can’t perform that action at this time.
0 commit comments