NOTE(review): This patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Context-line indentation and blank-line
placement are inferred from the hunk line counts and a 2-space indent style;
confirm them against the target files (or apply with
`git apply --ignore-whitespace`) before relying on this patch. The `index`
lines keep the hashes of the original patch; the post-image hashes no longer
match because the added code was revised.

diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py
index 324abe230b..e1486c932b 100644
--- a/src/google/adk/sessions/base_session_service.py
+++ b/src/google/adk/sessions/base_session_service.py
@@ -124,6 +124,25 @@ async def append_event(self, session: Session, event: Event) -> Event:
     session.events.append(event)
     return event
 
+  async def append_events_batch(
+      self, session: Session, events: list[Event]
+  ) -> list[Event]:
+    """Appends multiple events to a session, one after another.
+
+    The default implementation awaits `append_event` for each event in input
+    order, so an exception from one event stops the batch and leaves the
+    earlier events appended. Subclasses may override this to add concurrency
+    control or batching.
+
+    Args:
+      session: The session to append the events to.
+      events: The events to append, in the order they should be applied.
+
+    Returns:
+      The values returned by `append_event`, in the same order as `events`.
+    """
+    return [await self.append_event(session, event) for event in events]
+
   async def flush(self):
     """Flushes any buffered events.
 
diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py
index 8c1fdc134e..1b828dcb5d 100644
--- a/src/google/adk/sessions/vertex_ai_session_service.py
+++ b/src/google/adk/sessions/vertex_ai_session_service.py
@@ -18,6 +18,7 @@
 import datetime
 import json
 import logging
+import random
 import re
 from typing import Any
 from typing import Optional
@@ -352,10 +353,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
     async with self._get_api_client() as api_client:
 
       async def _do_append(cfg: dict[str, Any]):
-        await api_client.agent_engines.sessions.events.append(
-            name=(
-                f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}'
-            ),
+        await self._append_with_retry(
+            api_client,
+            reasoning_engine_id=reasoning_engine_id,
+            session_id=session.id,
             author=event.author,
             invocation_id=event.invocation_id,
             timestamp=datetime.datetime.fromtimestamp(
@@ -373,6 +374,108 @@ async def _do_append(cfg: dict[str, Any]):
         await _do_append(config)
     return event
 
+  # Upper bound on concurrent append_event calls issued by one batch.
+  _BATCH_MAX_CONCURRENCY = 5
+
+  @override
+  async def append_events_batch(
+      self, session: Session, events: list[Event]
+  ) -> list[Event]:
+    """Appends multiple events with bounded concurrency to avoid 429 errors.
+
+    At most `_BATCH_MAX_CONCURRENCY` `append_event` calls are in flight at a
+    time. Results are returned in the same order as `events`; the order in
+    which the individual appends complete is not guaranteed.
+
+    Args:
+      session: The session to append the events to.
+      events: The events to append.
+
+    Returns:
+      The appended events, in the same order as `events`.
+
+    Raises:
+      Exception: The first error raised by any `append_event` call. Appends
+        that were already started are not cancelled.
+    """
+    if not events:
+      return []
+
+    semaphore = asyncio.Semaphore(self._BATCH_MAX_CONCURRENCY)
+
+    async def _append_one(event: Event) -> Event:
+      async with semaphore:
+        return await self.append_event(session, event)
+
+    return list(await asyncio.gather(*(_append_one(e) for e in events)))
+
+  # Retry policy for 429 RESOURCE_EXHAUSTED responses from events.append:
+  # exponential backoff starting at _RETRY_INITIAL_DELAY seconds, multiplied
+  # by _RETRY_EXP_BASE after each retry and capped at _RETRY_MAX_DELAY.
+  _RETRY_MAX_ATTEMPTS = 5
+  _RETRY_INITIAL_DELAY = 1.0
+  _RETRY_MAX_DELAY = 30.0
+  _RETRY_EXP_BASE = 2.0
+
+  async def _append_with_retry(
+      self,
+      api_client: Any,
+      *,
+      reasoning_engine_id: str,
+      session_id: str,
+      author: str,
+      invocation_id: str,
+      timestamp: datetime.datetime,
+      config: dict[str, Any],
+  ) -> None:
+    """Appends an event to the API with retry on 429 RESOURCE_EXHAUSTED.
+
+    Args:
+      api_client: The client whose `agent_engines.sessions.events.append` is
+        called.
+      reasoning_engine_id: ID used to build the session resource name.
+      session_id: ID of the session the event is appended to.
+      author: Forwarded unchanged to `events.append`.
+      invocation_id: Forwarded unchanged to `events.append`.
+      timestamp: Forwarded unchanged to `events.append`.
+      config: Forwarded unchanged to `events.append`.
+
+    Raises:
+      ClientError: If the call fails with a code other than 429, or still
+        fails with 429 on the final attempt.
+    """
+    # Always make at least one attempt, even if the constant is overridden
+    # with a non-positive value.
+    max_attempts = max(1, self._RETRY_MAX_ATTEMPTS)
+    delay = self._RETRY_INITIAL_DELAY
+    for attempt in range(1, max_attempts + 1):
+      try:
+        await api_client.agent_engines.sessions.events.append(
+            name=(
+                f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
+            ),
+            author=author,
+            invocation_id=invocation_id,
+            timestamp=timestamp,
+            config=config,
+        )
+        return
+      except ClientError as e:
+        if e.code != 429 or attempt == max_attempts:
+          raise
+        # Up to 50% jitter keeps concurrent callers from retrying in lockstep.
+        jitter = random.uniform(0, delay * 0.5)
+        wait = min(delay + jitter, self._RETRY_MAX_DELAY)
+        logger.warning(
+            'Rate limited (429) on append_event, attempt %d/%d.'
+            ' Retrying in %.1fs.',
+            attempt,
+            max_attempts,
+            wait,
+        )
+        await asyncio.sleep(wait)
+        delay = min(delay * self._RETRY_EXP_BASE, self._RETRY_MAX_DELAY)
+
   def _get_reasoning_engine_id(self, app_name: str):
     if self._agent_engine_id:
       return self._agent_engine_id
diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py
index 02f5159a45..5163db90af 100644
--- a/tests/unittests/sessions/test_session_service.py
+++ b/tests/unittests/sessions/test_session_service.py
@@ -1650,3 +1650,31 @@ async def tracking_fn(**kwargs):
   finally:
     database_session_service._select_required_state = original_fn
     await service.close()
+
+
+@pytest.mark.asyncio
+async def test_append_events_batch_sequential(session_service):
+  """Tests that append_events_batch appends all events sequentially."""
+  session = await session_service.create_session(
+      app_name='my_app', user_id='user'
+  )
+
+  events = [
+      Event(
+          invocation_id=f'batch_{i}',
+          author='user',
+          content=types.Content(
+              role='user', parts=[types.Part(text=f'msg_{i}')]
+          ),
+      )
+      for i in range(3)
+  ]
+  expected_ids = [event.invocation_id for event in events]
+
+  results = await session_service.append_events_batch(session, events)
+
+  assert len(results) == 3
+  assert len(session.events) == 3
+  # Sequential appends must keep the input order.
+  assert [event.invocation_id for event in results] == expected_ids
+  assert [event.invocation_id for event in session.events] == expected_ids
diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py
index c5c9996ef5..d747475a10 100644
--- a/tests/unittests/sessions/test_vertex_ai_session_service.py
+++ b/tests/unittests/sessions/test_vertex_ai_session_service.py
@@ -1311,3 +1311,142 @@ class DummyModel(pydantic.BaseModel):
 
   assert appended_event.actions.compaction is not None
   assert appended_event.actions.compaction.start_timestamp == 1000.0
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures('mock_get_api_client')
+async def test_append_event_retries_on_429(mock_api_client_instance):
+  """Tests that append_event retries with backoff on 429 RESOURCE_EXHAUSTED."""
+  session_service = mock_vertex_ai_session_service()
+  session = await session_service.get_session(
+      app_name='123', user_id='user', session_id='1'
+  )
+
+  call_count = 0
+
+  async def side_effect(name, author, invocation_id, timestamp, config):
+    nonlocal call_count
+    call_count += 1
+    if call_count <= 2:
+      raise ClientError(
+          code=429,
+          response_json={'message': 'RESOURCE_EXHAUSTED'},
+          response=None,
+      )
+    return await mock_api_client_instance._append_event(
+        name, author, invocation_id, timestamp, config
+    )
+
+  mock_api_client_instance.agent_engines.sessions.events.append.side_effect = (
+      side_effect
+  )
+
+  event_to_append = Event(
+      invocation_id='retry_invocation',
+      author='model',
+      timestamp=1734005535.0,
+      content=genai_types.Content(
+          parts=[genai_types.Part(text='retry_content')]
+      ),
+  )
+
+  with mock.patch('asyncio.sleep', new_callable=mock.AsyncMock) as mock_sleep:
+    await session_service.append_event(session, event_to_append)
+
+  assert call_count == 3
+  # Two 429 responses mean two backoff sleeps before the successful call.
+  assert mock_sleep.await_count == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures('mock_get_api_client')
+async def test_append_event_raises_after_max_retries(mock_api_client_instance):
+  """Tests that append_event raises after exhausting all retry attempts."""
+  session_service = mock_vertex_ai_session_service()
+  session = await session_service.get_session(
+      app_name='123', user_id='user', session_id='1'
+  )
+
+  mock_api_client_instance.agent_engines.sessions.events.append.side_effect = (
+      ClientError(
+          code=429,
+          response_json={'message': 'RESOURCE_EXHAUSTED'},
+          response=None,
+      )
+  )
+
+  event_to_append = Event(
+      invocation_id='exhaust_invocation',
+      author='model',
+      timestamp=1734005536.0,
+  )
+
+  with mock.patch('asyncio.sleep', new_callable=mock.AsyncMock):
+    with pytest.raises(ClientError) as excinfo:
+      await session_service.append_event(session, event_to_append)
+  assert excinfo.value.code == 429
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures('mock_get_api_client')
+async def test_append_event_does_not_retry_on_non_429(mock_api_client_instance):
+  """Tests that non-429 ClientErrors are raised immediately without retry."""
+  session_service = mock_vertex_ai_session_service()
+  session = await session_service.get_session(
+      app_name='123', user_id='user', session_id='1'
+  )
+
+  mock_api_client_instance.agent_engines.sessions.events.append.side_effect = (
+      ClientError(
+          code=400,
+          response_json={'message': 'BAD_REQUEST'},
+          response=None,
+      )
+  )
+
+  event_to_append = Event(
+      invocation_id='no_retry_invocation',
+      author='model',
+      timestamp=1734005537.0,
+  )
+
+  with mock.patch('asyncio.sleep', new_callable=mock.AsyncMock) as mock_sleep:
+    with pytest.raises(ClientError) as excinfo:
+      await session_service.append_event(session, event_to_append)
+  assert excinfo.value.code == 400
+  # No backoff sleep may happen for a non-retryable error.
+  mock_sleep.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures('mock_get_api_client')
+async def test_append_events_batch(mock_api_client_instance):
+  """Tests that append_events_batch appends multiple events."""
+  session_service = mock_vertex_ai_session_service()
+  session = await session_service.get_session(
+      app_name='123', user_id='user', session_id='1'
+  )
+  initial_event_count = len(session.events)
+
+  events_to_append = [
+      Event(
+          invocation_id=f'batch_{i}',
+          author='model',
+          timestamp=1734005540.0 + i,
+          content=genai_types.Content(
+              parts=[genai_types.Part(text=f'batch_content_{i}')]
+          ),
+      )
+      for i in range(5)
+  ]
+
+  results = await session_service.append_events_batch(
+      session, events_to_append
+  )
+
+  assert len(results) == 5
+  assert len(session.events) == initial_event_count + 5
+  # gather() keeps results in input order despite concurrent appends.
+  assert [event.invocation_id for event in results] == [
+      event.invocation_id for event in events_to_append
+  ]