Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,19 @@ async def append_event(self, session: Session, event: Event) -> Event:
session.events.append(event)
return event

async def append_events_batch(
    self, session: Session, events: list[Event]
) -> list[Event]:
  """Appends several events to a session, one after another.

  This default implementation awaits `append_event` for each event in the
  order given. Subclasses may override it to add concurrency control or
  batching.

  Args:
    session: The session that receives the events.
    events: The events to append, in order.

  Returns:
    The value returned by `append_event` for each event, in input order.
  """
  # Each await completes before the next event is handed to append_event.
  return [await self.append_event(session, item) for item in events]

async def flush(self):
"""Flushes any buffered events.

Expand Down
68 changes: 64 additions & 4 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import datetime
import json
import logging
import random
import re
from typing import Any
from typing import Optional
Expand Down Expand Up @@ -352,10 +353,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
async with self._get_api_client() as api_client:

async def _do_append(cfg: dict[str, Any]):
await api_client.agent_engines.sessions.events.append(
name=(
f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
),
await self._append_with_retry(
api_client,
reasoning_engine_id=reasoning_engine_id,
session_id=session.id,
author=event.author,
invocation_id=event.invocation_id,
timestamp=datetime.datetime.fromtimestamp(
Expand All @@ -373,6 +374,65 @@ async def _do_append(cfg: dict[str, Any]):
await _do_append(config)
return event

# Maximum number of `append_event` calls a single batch keeps in flight.
_BATCH_MAX_CONCURRENCY = 5

@override
async def append_events_batch(
    self, session: Session, events: list[Event]
) -> list[Event]:
  """Appends multiple events with concurrency control to avoid 429 errors.

  At most `_BATCH_MAX_CONCURRENCY` calls to `append_event` run at the same
  time. All appends are allowed to settle before this method returns or
  raises, so no append is left running in the background after a failure.

  Args:
    session: The session that receives the events.
    events: The events to append.

  Returns:
    The value returned by `append_event` for each event, in input order.

  Raises:
    BaseException: The first exception (in input order) raised by any of the
      `append_event` calls.
  """
  if not events:
    return []

  semaphore = asyncio.Semaphore(self._BATCH_MAX_CONCURRENCY)

  async def _append_one(event: Event) -> Event:
    async with semaphore:
      return await self.append_event(session, event)

  # With the default return_exceptions=False, gather raises on the first
  # failure but leaves the remaining appends running detached. Collect every
  # outcome first, then surface the first failure.
  outcomes = await asyncio.gather(
      *(_append_one(event) for event in events), return_exceptions=True
  )
  for outcome in outcomes:
    if isinstance(outcome, BaseException):
      raise outcome
  return list(outcomes)

# Retry policy read by `_append_with_retry` when the API answers with 429.
# Number of attempts made in total, counting the first call.
_RETRY_MAX_ATTEMPTS = 5
# Base delay before the first retry; the wait is passed to `asyncio.sleep`,
# so it is expressed in seconds.
_RETRY_INITIAL_DELAY = 1.0
# Upper bound applied to both the base delay and the jittered wait.
_RETRY_MAX_DELAY = 30.0
# Factor the base delay is multiplied by after each rate-limited attempt.
_RETRY_EXP_BASE = 2.0

async def _append_with_retry(
    self,
    api_client: Any,
    *,
    reasoning_engine_id: str,
    session_id: str,
    author: str,
    invocation_id: str,
    timestamp: datetime.datetime,
    config: dict[str, Any],
) -> None:
  """Appends an event to the API with retry on 429 RESOURCE_EXHAUSTED.

  A `ClientError` whose code is 429 is retried with exponential backoff plus
  random jitter, until the attempt budget is used up. Any other `ClientError`,
  and the 429 raised by the final attempt, propagates to the caller.

  Args:
    api_client: Client exposing `agent_engines.sessions.events.append`.
    reasoning_engine_id: ID used to build the session resource name.
    session_id: ID of the session the event is appended to.
    author: Forwarded to the API as `author`.
    invocation_id: Forwarded to the API as `invocation_id`.
    timestamp: Forwarded to the API as `timestamp`.
    config: Forwarded to the API as `config`.

  Raises:
    ClientError: On a non-429 error, or on a 429 from the last attempt.
  """
  # Always make at least one attempt: with a non-positive
  # `_RETRY_MAX_ATTEMPTS` the loop would be skipped entirely and the event
  # would be dropped without any error.
  max_attempts = max(1, self._RETRY_MAX_ATTEMPTS)
  session_name = (
      f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
  )
  delay = self._RETRY_INITIAL_DELAY

  for attempt in range(1, max_attempts + 1):
    try:
      await api_client.agent_engines.sessions.events.append(
          name=session_name,
          author=author,
          invocation_id=invocation_id,
          timestamp=timestamp,
          config=config,
      )
      return
    except ClientError as e:
      # Only rate limiting is retried, and only while attempts remain.
      if e.code != 429 or attempt == max_attempts:
        raise
      # Up to 50% extra wait so concurrent callers do not retry in lockstep.
      jitter = random.uniform(0, delay * 0.5)
      wait = min(delay + jitter, self._RETRY_MAX_DELAY)
      logger.warning(
          'Rate limited (429) on append_event, attempt %d/%d.'
          ' Retrying in %.1fs.',
          attempt,
          max_attempts,
          wait,
      )
      await asyncio.sleep(wait)
      delay = min(delay * self._RETRY_EXP_BASE, self._RETRY_MAX_DELAY)

def _get_reasoning_engine_id(self, app_name: str):
if self._agent_engine_id:
return self._agent_engine_id
Expand Down
23 changes: 23 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,3 +1650,26 @@ async def tracking_fn(**kwargs):
finally:
database_session_service._select_required_state = original_fn
await service.close()


@pytest.mark.asyncio
async def test_append_events_batch_sequential(session_service):
  """Tests that append_events_batch appends all events sequentially."""
  session = await session_service.create_session(
      app_name='my_app', user_id='user'
  )

  events = [
      Event(
          invocation_id=f'batch_{i}',
          author='user',
          content=types.Content(
              role='user', parts=[types.Part(text=f'msg_{i}')]
          ),
      )
      for i in range(3)
  ]
  expected_ids = [event.invocation_id for event in events]

  results = await session_service.append_events_batch(session, events)

  assert len(results) == 3
  assert len(session.events) == 3
  # Counts alone would also pass for a reordered batch; pin the order that
  # "sequentially" promises, both in the return value and on the session.
  assert [event.invocation_id for event in results] == expected_ids
  assert [event.invocation_id for event in session.events] == expected_ids
130 changes: 130 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,3 +1311,133 @@ class DummyModel(pydantic.BaseModel):

assert appended_event.actions.compaction is not None
assert appended_event.actions.compaction.start_timestamp == 1000.0


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_retries_on_429(mock_api_client_instance):
  """Checks that append_event succeeds after two 429 responses."""
  service = mock_vertex_ai_session_service()
  session = await service.get_session(
      app_name='123', user_id='user', session_id='1'
  )

  # One entry is recorded per call to the patched `events.append`.
  attempts = []

  async def flaky_append(name, author, invocation_id, timestamp, config):
    attempts.append(name)
    # The first two calls are rate limited; later calls are forwarded to the
    # mock client's own append handler.
    if len(attempts) > 2:
      return await mock_api_client_instance._append_event(
          name, author, invocation_id, timestamp, config
      )
    raise ClientError(
        code=429,
        response_json={'message': 'RESOURCE_EXHAUSTED'},
        response=None,
    )

  events_api = mock_api_client_instance.agent_engines.sessions.events
  events_api.append.side_effect = flaky_append

  retried_event = Event(
      invocation_id='retry_invocation',
      author='model',
      timestamp=1734005535.0,
      content=genai_types.Content(
          parts=[genai_types.Part(text='retry_content')]
      ),
  )

  # Replace asyncio.sleep so the backoff waits return immediately.
  with mock.patch('asyncio.sleep', new_callable=mock.AsyncMock):
    await service.append_event(session, retried_event)

  assert len(attempts) == 3


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_raises_after_max_retries(mock_api_client_instance):
  """Tests that append_event raises after exhausting all retry attempts."""
  session_service = mock_vertex_ai_session_service()
  session = await session_service.get_session(
      app_name='123', user_id='user', session_id='1'
  )

  append_mock = mock_api_client_instance.agent_engines.sessions.events.append
  append_mock.side_effect = ClientError(
      code=429,
      response_json={'message': 'RESOURCE_EXHAUSTED'},
      response=None,
  )
  # Measure a delta so earlier calls on the shared mock cannot skew the count.
  calls_before = append_mock.call_count

  event_to_append = Event(
      invocation_id='exhaust_invocation',
      author='model',
      timestamp=1734005536.0,
  )

  # Patch out sleeping so the backoff waits do not slow the test down.
  with mock.patch('asyncio.sleep', new_callable=mock.AsyncMock):
    with pytest.raises(ClientError) as excinfo:
      await session_service.append_event(session, event_to_append)

  # Checked after the `raises` block: a statement placed inside it after the
  # raising call would never run.
  assert excinfo.value.code == 429
  # Without this the test would also pass if the first 429 were raised
  # immediately, with no retry at all.
  assert (
      append_mock.call_count - calls_before
      == session_service._RETRY_MAX_ATTEMPTS
  )


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_does_not_retry_on_non_429(mock_api_client_instance):
  """Tests that non-429 ClientErrors are raised immediately without retry."""
  session_service = mock_vertex_ai_session_service()
  session = await session_service.get_session(
      app_name='123', user_id='user', session_id='1'
  )

  append_mock = mock_api_client_instance.agent_engines.sessions.events.append
  append_mock.side_effect = ClientError(
      code=400,
      response_json={'message': 'BAD_REQUEST'},
      response=None,
  )
  # Measure a delta so earlier calls on the shared mock cannot skew the count.
  calls_before = append_mock.call_count

  event_to_append = Event(
      invocation_id='no_retry_invocation',
      author='model',
      timestamp=1734005537.0,
  )

  with pytest.raises(ClientError) as excinfo:
    await session_service.append_event(session, event_to_append)

  # Checked after the `raises` block: a statement placed inside it after the
  # raising call would never run.
  assert excinfo.value.code == 400
  # "Without retry" means exactly one API call; the error code alone would
  # look the same if the 400 had been retried before surfacing.
  assert append_mock.call_count - calls_before == 1


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_events_batch(mock_api_client_instance):
  """Checks that a five-event batch is returned and added to the session."""
  batch_size = 5
  service = mock_vertex_ai_session_service()
  session = await service.get_session(
      app_name='123', user_id='user', session_id='1'
  )
  # The fetched session may already hold events, so compare against a
  # baseline rather than an absolute count.
  events_before = len(session.events)

  batch = []
  for index in range(batch_size):
    text_part = genai_types.Part(text=f'batch_content_{index}')
    batch.append(
        Event(
            invocation_id=f'batch_{index}',
            author='model',
            timestamp=1734005540.0 + index,
            content=genai_types.Content(parts=[text_part]),
        )
    )

  appended = await service.append_events_batch(session, batch)

  assert len(appended) == batch_size
  assert len(session.events) - events_before == batch_size
Loading