|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | import logging |
| 5 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
5 | 6 | from typing import TYPE_CHECKING, Any, cast |
6 | 7 |
|
7 | 8 | import boto3 |
@@ -259,7 +260,21 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio |
259 | 260 | def list_messages( |
260 | 261 | self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any |
261 | 262 | ) -> list[SessionMessage]: |
262 | | - """List messages for an agent with pagination from S3.""" |
| 263 | + """List messages for an agent with pagination from S3. |
| 264 | +
|
| 265 | + Args: |
| 266 | + session_id: ID of the session |
| 267 | + agent_id: ID of the agent |
| 268 | + limit: Optional limit on number of messages to return |
| 269 | + offset: Optional offset for pagination |
| 270 | + **kwargs: Additional keyword arguments |
| 271 | +
|
| 272 | + Returns: |
| 273 | + List of SessionMessage objects, sorted by message_id. |
| 274 | +
|
| 275 | + Raises: |
| 276 | + SessionException: If S3 error occurs during message retrieval. |
| 277 | + """ |
263 | 278 | messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" |
264 | 279 | try: |
265 | 280 | paginator = self.client.get_paginator("list_objects_v2") |
@@ -287,10 +302,38 @@ def list_messages( |
287 | 302 | else: |
288 | 303 | message_keys = message_keys[offset:] |
289 | 304 |
|
290 | | - # Load only the required message objects |
| 305 | + # Load message objects in parallel for better performance |
291 | 306 | messages: list[SessionMessage] = [] |
292 | | - for key in message_keys: |
293 | | - message_data = self._read_s3_object(key) |
| 307 | + if not message_keys: |
| 308 | + return messages |
| 309 | + |
| 310 | + # Optimize for single worker case - avoid thread pool overhead |
| 311 | + if len(message_keys) == 1: |
| 312 | + for key in message_keys: |
| 313 | + message_data = self._read_s3_object(key) |
| 314 | + if message_data: |
| 315 | + messages.append(SessionMessage.from_dict(message_data)) |
| 316 | + return messages |
| 317 | + |
| 318 | + with ThreadPoolExecutor() as executor: |
| 319 | + # Submit all read tasks |
| 320 | + future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys} |
| 321 | + |
| 322 | + # Create a mapping from key to index to maintain order |
| 323 | + key_to_index = {key: idx for idx, key in enumerate(message_keys)} |
| 324 | + |
| 325 | + # Initialize results list with None placeholders to maintain order |
| 326 | + results: list[dict[str, Any] | None] = [None] * len(message_keys) |
| 327 | + |
| 328 | + # Process results as they complete |
| 329 | + for future in as_completed(future_to_key): |
| 330 | + key = future_to_key[future] |
| 331 | + message_data = future.result() |
| 332 | + # Store result at the correct index to maintain order |
| 333 | + results[key_to_index[key]] = message_data |
| 334 | + |
| 335 | + # Convert results to SessionMessage objects, filtering out None values |
| 336 | + for message_data in results: |
294 | 337 | if message_data: |
295 | 338 | messages.append(SessionMessage.from_dict(message_data)) |
296 | 339 |
|
|
0 commit comments