Skip to content

Commit dac3f31

Browse files
committed
fix: switch to enum
1 parent 8a2695a commit dac3f31

3 files changed

Lines changed: 32 additions & 11 deletions

File tree

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ async def stream_async(
633633
# Using threading.Lock instead of asyncio.Lock because run_async() creates
634634
# separate event loops in different threads
635635
lock_acquired = False
636-
if self._concurrent_invocation_mode == "throw":
636+
if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW:
637637
lock_acquired = self._invocation_lock.acquire(blocking=False)
638638
if not lock_acquired:
639639
raise ConcurrencyException(

src/strands/types/agent.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,26 @@
33
This module defines the types used for an Agent.
44
"""
55

6-
from typing import Literal, TypeAlias
6+
from enum import Enum
7+
from typing import TypeAlias
78

89
from .content import ContentBlock, Messages
910
from .interrupt import InterruptResponseContent
1011

1112
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None
1213

13-
ConcurrentInvocationMode = Literal["throw", "unsafe_reentrant"]
14-
"""Mode controlling concurrent invocation behavior.
1514

16-
Values:
17-
throw: Raises ConcurrencyException if concurrent invocation is attempted (default).
18-
unsafe_reentrant: Allows concurrent invocations without locking.
15+
class ConcurrentInvocationMode(str, Enum):
16+
"""Mode controlling concurrent invocation behavior.
1917
20-
Warning:
21-
The ``unsafe_reentrant`` mode makes no guarantees about resulting behavior and is
22-
provided only for advanced use cases where the caller understands the risks.
23-
"""
18+
Values:
19+
THROW: Raises ConcurrencyException if concurrent invocation is attempted (default).
20+
UNSAFE_REENTRANT: Allows concurrent invocations without locking.
21+
22+
Warning:
23+
The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is
24+
provided only for advanced use cases where the caller understands the risks.
25+
"""
26+
27+
THROW = "throw"
28+
UNSAFE_REENTRANT = "unsafe_reentrant"

tests/strands/agent/test_agent.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,22 @@ def test_agent_concurrent_invocation_mode_stores_value():
23822382
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"
23832383

23842384

2385+
def test_agent_concurrent_invocation_mode_accepts_enum():
2386+
"""Test that concurrent_invocation_mode accepts enum values as well as strings."""
2387+
from strands.types.agent import ConcurrentInvocationMode
2388+
2389+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
2390+
2391+
# Using enum values
2392+
agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW)
2393+
assert agent_throw._concurrent_invocation_mode == "throw"
2394+
assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW
2395+
2396+
agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT)
2397+
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"
2398+
assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT
2399+
2400+
23852401
@pytest.mark.asyncio
23862402
async def test_agent_sequential_invocations_work():
23872403
"""Test that sequential invocations work correctly after lock is released."""

0 commit comments

Comments
 (0)