From 826c812c0f43988b012a8b994cb31c0e6b3442d1 Mon Sep 17 00:00:00 2001 From: Chris Kinzel Date: Sat, 16 May 2026 21:22:14 +0000 Subject: [PATCH] fix(otel): prevent contextvars leak across async generators (#5722) --- src/google/adk/agents/base_agent.py | 39 ++++++++++++++++++++++++----- src/google/adk/runners.py | 39 +++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 91fb568cd3..9faeee72ef 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -31,6 +31,7 @@ from typing import Union from google.genai import types +from opentelemetry import context from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field @@ -284,22 +285,35 @@ async def run_async( Event: the events generated by the agent. """ + caller_ctx = context.get_current() ctx = self._create_invocation_context(parent_context) async with _instrumentation.record_agent_invocation(ctx, self): if event := await self._handle_before_agent_callback(ctx): - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) if ctx.end_invocation: return async with Aclosing(self._run_async_impl(ctx)) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) if ctx.end_invocation: return if event := await self._handle_after_agent_callback(ctx): - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) @final async def run_live( @@ -316,19 +330,32 @@ async def run_live( Event: the events generated by the agent. """ + caller_ctx = context.get_current() ctx = self._create_invocation_context(parent_context) async with _instrumentation.record_agent_invocation(ctx, self): if event := await self._handle_before_agent_callback(ctx): - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) if ctx.end_invocation: return async with Aclosing(self._run_live_impl(ctx)) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) if event := await self._handle_after_agent_callback(ctx): - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) async def _run_async_impl( self, ctx: InvocationContext diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 850c26bbba..3d21dcb4eb 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -30,6 +30,7 @@ import warnings from google.genai import types +from opentelemetry import context from .agents.base_agent import BaseAgent from .agents.base_agent import BaseAgentState @@ -537,10 +538,13 @@ async def run_async( if new_message and not new_message.role: new_message.role = 'user' + caller_ctx = context.get_current() + async def _run_with_trace( new_message: Optional[types.Content] = None, invocation_id: Optional[str] = None, ) -> AsyncGenerator[Event, None]: + caller_ctx_trace = context.get_current() with tracer.start_as_current_span('invocation'): session = await self._get_or_create_session( user_id=user_id, @@ -600,9 +604,14 @@ async def _run_with_trace( return async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + caller_ctx_exec = context.get_current() async with Aclosing(ctx.agent.run_async(ctx)) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx_exec) + try: + yield event + finally: + context.detach(token) async with Aclosing( self._exec_with_plugin( @@ -613,7 +622,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: ) ) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx_trace) + try: + yield event + finally: + context.detach(token) # Run compaction after all events are yielded from the agent. # (We don't compact in the middle of an invocation, we only compact at # the end of an invocation.) @@ -630,7 +643,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) async def rewind_async( self, @@ -1104,6 +1121,9 @@ async def run_live( # AUDIO by default. if run_config.response_modalities is None: run_config.response_modalities = ['AUDIO'] + + caller_ctx = context.get_current() + if session is None and (user_id is None or session_id is None): raise ValueError( 'Either session or user_id and session_id must be provided.' @@ -1135,9 +1155,14 @@ async def run_live( ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + caller_ctx_exec = context.get_current() async with Aclosing(ctx.agent.run_live(ctx)) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx_exec) + try: + yield event + finally: + context.detach(token) async with Aclosing( self._exec_with_plugin( @@ -1148,7 +1173,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: ) ) as agen: async for event in agen: - yield event + token = context.attach(caller_ctx) + try: + yield event + finally: + context.detach(token) def _find_agent_to_run( self, session: Session, root_agent: BaseAgent