Skip to content

Commit eb03e21

Browse files
committed
fix: address feedback
1 parent 701dd29 commit eb03e21

2 files changed

Lines changed: 4 additions & 8 deletions

File tree

src/strands/agent/agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,6 @@ async def stream_async(
632632
# Conditionally acquire lock based on concurrent_invocation_mode
633633
# Using threading.Lock instead of asyncio.Lock because run_async() creates
634634
# separate event loops in different threads
635-
lock_acquired = False
636635
if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW:
637636
lock_acquired = self._invocation_lock.acquire(blocking=False)
638637
if not lock_acquired:
@@ -687,7 +686,7 @@ async def stream_async(
687686
raise
688687

689688
finally:
690-
if lock_acquired:
689+
if self._invocation_lock.locked():
691690
self._invocation_lock.release()
692691

693692
async def _run_loop(

tests/strands/agent/test_agent.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from strands.session.repository_session_manager import RepositorySessionManager
2727
from strands.telemetry.tracer import serialize
2828
from strands.types._events import EventLoopStopEvent, ModelStreamEvent
29+
from strands.types.agent import ConcurrentInvocationMode
2930
from strands.types.content import Messages
3031
from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException
3132
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
@@ -2235,16 +2236,13 @@ def test_agent_concurrent_call_raises_exception():
22352236

22362237
results = []
22372238
errors = []
2238-
lock = threading.Lock()
22392239

22402240
def invoke():
22412241
try:
22422242
result = agent("test")
2243-
with lock:
2244-
results.append(result)
2243+
results.append(result)
22452244
except ConcurrencyException as e:
2246-
with lock:
2247-
errors.append(e)
2245+
errors.append(e)
22482246

22492247
# Start first thread and wait for it to begin streaming
22502248
t1 = threading.Thread(target=invoke)
@@ -2384,7 +2382,6 @@ def test_agent_concurrent_invocation_mode_stores_value():
23842382

23852383
def test_agent_concurrent_invocation_mode_accepts_enum():
23862384
"""Test that concurrent_invocation_mode accepts enum values as well as strings."""
2387-
from strands.types.agent import ConcurrentInvocationMode
23882385

23892386
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
23902387

0 commit comments

Comments
 (0)