Skip to content

Commit 51cbe7b

Browse files
CrysisDeu, JackYPCOnline, Unshure
authored
Add parallel reading support to S3SessionManager.list_messages() (strands-agents#1186)
Co-authored-by: Jack Yuan <jackypc@amazon.com> Co-authored-by: Nicholas Clegg <ncclegg@amazon.com>
1 parent e4bd3bc commit 51cbe7b

3 files changed

Lines changed: 86 additions & 8 deletions

File tree

src/strands/session/s3_session_manager.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
56
from typing import TYPE_CHECKING, Any, cast
67

78
import boto3
@@ -259,7 +260,21 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
259260
def list_messages(
260261
self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any
261262
) -> list[SessionMessage]:
262-
"""List messages for an agent with pagination from S3."""
263+
"""List messages for an agent with pagination from S3.
264+
265+
Args:
266+
session_id: ID of the session
267+
agent_id: ID of the agent
268+
limit: Optional limit on number of messages to return
269+
offset: Optional offset for pagination
270+
**kwargs: Additional keyword arguments
271+
272+
Returns:
273+
List of SessionMessage objects, sorted by message_id.
274+
275+
Raises:
276+
SessionException: If S3 error occurs during message retrieval.
277+
"""
263278
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
264279
try:
265280
paginator = self.client.get_paginator("list_objects_v2")
@@ -287,10 +302,38 @@ def list_messages(
287302
else:
288303
message_keys = message_keys[offset:]
289304

290-
# Load only the required message objects
305+
# Load message objects in parallel for better performance
291306
messages: list[SessionMessage] = []
292-
for key in message_keys:
293-
message_data = self._read_s3_object(key)
307+
if not message_keys:
308+
return messages
309+
310+
# Optimize for the single-message case - avoid thread pool overhead
311+
if len(message_keys) == 1:
312+
for key in message_keys:
313+
message_data = self._read_s3_object(key)
314+
if message_data:
315+
messages.append(SessionMessage.from_dict(message_data))
316+
return messages
317+
318+
with ThreadPoolExecutor() as executor:
319+
# Submit all read tasks
320+
future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys}
321+
322+
# Create a mapping from key to index to maintain order
323+
key_to_index = {key: idx for idx, key in enumerate(message_keys)}
324+
325+
# Initialize results list with None placeholders to maintain order
326+
results: list[dict[str, Any] | None] = [None] * len(message_keys)
327+
328+
# Process results as they complete
329+
for future in as_completed(future_to_key):
330+
key = future_to_key[future]
331+
message_data = future.result()
332+
# Store result at the correct index to maintain order
333+
results[key_to_index[key]] = message_data
334+
335+
# Convert results to SessionMessage objects, filtering out None values
336+
for message_data in results:
294337
if message_data:
295338
messages.append(SessionMessage.from_dict(message_data))
296339

tests/strands/models/test_bedrock.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,11 @@ def test__init__region_precedence(mock_client_method, session_cls):
201201
def test__init__with_endpoint_url(mock_client_method):
202202
"""Test that BedrockModel uses the provided endpoint_url for VPC endpoints."""
203203
custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com"
204-
BedrockModel(endpoint_url=custom_endpoint)
205-
mock_client_method.assert_called_with(
206-
region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint
207-
)
204+
with unittest.mock.patch.object(os, "environ", {}):
205+
BedrockModel(endpoint_url=custom_endpoint)
206+
mock_client_method.assert_called_with(
207+
region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint
208+
)
208209

209210

210211
def test__init__with_region_and_session_raises_value_error():

tests/strands/session/test_s3_session_manager.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,40 @@ def test_list_messages_all(s3_manager, sample_session, sample_agent):
282282
assert len(result) == 5
283283

284284

285+
def test_list_messages_single_message(s3_manager, sample_session, sample_agent):
286+
"""Test listing messages from S3 when only a single message exists."""
287+
# Create session and agent
288+
s3_manager.create_session(sample_session)
289+
s3_manager.create_agent(sample_session.session_id, sample_agent)
290+
291+
# Create single message
292+
message = SessionMessage(
293+
{
294+
"role": "user",
295+
"content": [ContentBlock(text="Single Message")],
296+
},
297+
0,
298+
)
299+
s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message)
300+
301+
# List all messages
302+
result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
303+
304+
assert len(result) == 1
305+
306+
307+
def test_list_no_messages(s3_manager, sample_session, sample_agent):
308+
"""Test listing messages from S3 when no messages exist."""
309+
# Create session and agent
310+
s3_manager.create_session(sample_session)
311+
s3_manager.create_agent(sample_session.session_id, sample_agent)
312+
313+
# List all messages
314+
result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
315+
316+
assert len(result) == 0
317+
318+
285319
def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent):
286320
"""Test listing messages with pagination in S3."""
287321
# Create session and agent

0 commit comments

Comments
 (0)