Skip to content

Commit 60b9073

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: handle ContextVar detach errors during task cancellation
Refactors context manager usage in asynchronous generators to gracefully handle ValueError during cleanup, preventing crashes during task cancellation. PiperOrigin-RevId: 902183602
1 parent 5578772 commit 60b9073

5 files changed

Lines changed: 161 additions & 3 deletions

File tree

src/google/adk/agents/base_agent.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import inspect
1818
import logging
19+
import sys
1920
from typing import Any
2021
from typing import AsyncGenerator
2122
from typing import Awaitable
@@ -285,7 +286,9 @@ async def run_async(
285286
Event: the events generated by the agent.
286287
"""
287288

288-
with tracer.start_as_current_span(f'invoke_agent {self.name}') as span:
289+
cm = tracer.start_as_current_span(f'invoke_agent {self.name}')
290+
span = cm.__enter__()
291+
try:
289292
ctx = self._create_invocation_context(parent_context)
290293
tracing.trace_agent_invocation(span, self, ctx)
291294
if event := await self._handle_before_agent_callback(ctx):
@@ -302,6 +305,23 @@ async def run_async(
302305

303306
if event := await self._handle_after_agent_callback(ctx):
304307
yield event
308+
except BaseException:
309+
try:
310+
cm.__exit__(*sys.exc_info())
311+
except ValueError:
312+
logger.warning(
313+
'Failed to detach context during generator cleanup, likely due to'
314+
' cancellation.'
315+
)
316+
raise
317+
else:
318+
try:
319+
cm.__exit__(None, None, None)
320+
except ValueError:
321+
logger.warning(
322+
'Failed to detach context during generator cleanup, likely due to'
323+
' cancellation.'
324+
)
305325

306326
@final
307327
async def run_live(

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import asyncio
1919
import inspect
2020
import logging
21+
import sys
2122
from typing import AsyncGenerator
2223
from typing import Optional
2324
from typing import TYPE_CHECKING
@@ -1168,7 +1169,9 @@ async def _call_llm_async(
11681169
) -> AsyncGenerator[LlmResponse, None]:
11691170

11701171
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
1171-
with tracer.start_as_current_span('call_llm') as span:
1172+
cm = tracer.start_as_current_span('call_llm')
1173+
span = cm.__enter__()
1174+
try:
11721175
# Runs before_model_callback inside the call_llm span so
11731176
# plugins observe the same span as after/error callbacks.
11741177
if response := await self._handle_before_model_callback(
@@ -1261,6 +1264,23 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
12611264
llm_response = altered
12621265

12631266
yield llm_response
1267+
except BaseException:
1268+
try:
1269+
cm.__exit__(*sys.exc_info())
1270+
except ValueError:
1271+
logger.warning(
1272+
'Failed to detach context during generator cleanup, likely due to'
1273+
' cancellation.'
1274+
)
1275+
raise
1276+
else:
1277+
try:
1278+
cm.__exit__(None, None, None)
1279+
except ValueError:
1280+
logger.warning(
1281+
'Failed to detach context during generator cleanup, likely due to'
1282+
' cancellation.'
1283+
)
12641284

12651285
async with Aclosing(_call_llm_with_tracing()) as agen:
12661286
async for event in agen:

src/google/adk/runners.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,9 @@ async def _run_with_trace(
543543
new_message: Optional[types.Content] = None,
544544
invocation_id: Optional[str] = None,
545545
) -> AsyncGenerator[Event, None]:
546-
with tracer.start_as_current_span('invocation'):
546+
cm = tracer.start_as_current_span('invocation')
547+
span = cm.__enter__()
548+
try:
547549
session = await self._get_or_create_session(
548550
user_id=user_id,
549551
session_id=session_id,
@@ -627,6 +629,23 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
627629
self.session_service,
628630
skip_token_compaction=invocation_context.token_compaction_checked,
629631
)
632+
except BaseException:
633+
try:
634+
cm.__exit__(*sys.exc_info())
635+
except ValueError:
636+
logger.warning(
637+
'Failed to detach context during generator cleanup, likely due to'
638+
' cancellation.'
639+
)
640+
raise
641+
else:
642+
try:
643+
cm.__exit__(None, None, None)
644+
except ValueError:
645+
logger.warning(
646+
'Failed to detach context during generator cleanup, likely due to'
647+
' cancellation.'
648+
)
630649

631650
async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
632651
async for event in agen:

tests/unittests/tools/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,17 @@ pytest_test(
3131
"//third_party/py/pytest_asyncio",
3232
],
3333
)
34+
35+
pytest_test(
36+
name = "test_context_cancellation",
37+
srcs = ["test_context_cancellation.py"],
38+
args = [
39+
"-p",
40+
"pytest_asyncio.plugin",
41+
],
42+
deps = [
43+
"//third_party/py/google/adk",
44+
"//third_party/py/google/genai",
45+
"//third_party/py/pytest_asyncio",
46+
],
47+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for ContextVar error handling during cancellation."""
16+
17+
import asyncio
18+
import logging
19+
from typing import AsyncGenerator
20+
from unittest import mock
21+
22+
from google.adk.agents.base_agent import BaseAgent
23+
from google.adk.agents.invocation_context import InvocationContext
24+
from google.adk.events.event import Event
25+
from google.adk.sessions.in_memory_session_service import InMemorySessionService
26+
from google.adk.telemetry.tracing import tracer
27+
from google.genai import types
28+
import pytest
29+
from typing_extensions import override
30+
31+
32+
class MinimalAgent(BaseAgent):
33+
34+
@override
35+
async def _run_async_impl(
36+
self, ctx: InvocationContext
37+
) -> AsyncGenerator[Event, None]:
38+
yield Event(
39+
author=self.name,
40+
content=types.Content(parts=[types.Part(text='Hello')]),
41+
)
42+
43+
44+
class FailingCM:
45+
46+
def __enter__(self):
47+
return mock.Mock()
48+
49+
def __exit__(self, exc_type, exc_val, exc_tb):
50+
raise ValueError('Mocked ContextVar error')
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_run_async_handles_context_var_error(
55+
caplog: pytest.LogCaptureFixture,
56+
):
57+
agent = MinimalAgent(name='test_agent')
58+
59+
with mock.patch.object(
60+
tracer, 'start_as_current_span', return_value=FailingCM()
61+
):
62+
63+
session_service = InMemorySessionService()
64+
session = await session_service.create_session(
65+
app_name='test', user_id='user'
66+
)
67+
ctx = InvocationContext(
68+
invocation_id='inv_id',
69+
agent=agent,
70+
session=session,
71+
session_service=session_service,
72+
)
73+
74+
events = []
75+
with caplog.at_level(logging.WARNING):
76+
async for event in agent.run_async(ctx):
77+
events.append(event)
78+
79+
assert len(events) == 1
80+
assert events[0].content.parts[0].text == 'Hello'
81+
82+
assert any(
83+
'Failed to detach context during generator cleanup' in record.message
84+
for record in caplog.records
85+
)

0 commit comments

Comments
 (0)