Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import Union

from google.genai import types
from opentelemetry import context
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
Expand Down Expand Up @@ -284,22 +285,35 @@ async def run_async(
Event: the events generated by the agent.
"""

caller_ctx = context.get_current()
ctx = self._create_invocation_context(parent_context)
async with _instrumentation.record_agent_invocation(ctx, self):
if event := await self._handle_before_agent_callback(ctx):
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)
Comment on lines +292 to +296
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it's doing the opposite of what's proposed in #5722 (comment) (attaching before yield, not before).

if ctx.end_invocation:
return

async with Aclosing(self._run_async_impl(ctx)) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

if ctx.end_invocation:
return

if event := await self._handle_after_agent_callback(ctx):
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

@final
async def run_live(
Expand All @@ -316,19 +330,32 @@ async def run_live(
Event: the events generated by the agent.
"""

caller_ctx = context.get_current()
ctx = self._create_invocation_context(parent_context)
async with _instrumentation.record_agent_invocation(ctx, self):
if event := await self._handle_before_agent_callback(ctx):
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)
if ctx.end_invocation:
return

async with Aclosing(self._run_live_impl(ctx)) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

if event := await self._handle_after_agent_callback(ctx):
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
39 changes: 34 additions & 5 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import warnings

from google.genai import types
from opentelemetry import context

from .agents.base_agent import BaseAgent
from .agents.base_agent import BaseAgentState
Expand Down Expand Up @@ -537,10 +538,13 @@ async def run_async(
if new_message and not new_message.role:
new_message.role = 'user'

caller_ctx = context.get_current()

async def _run_with_trace(
new_message: Optional[types.Content] = None,
invocation_id: Optional[str] = None,
) -> AsyncGenerator[Event, None]:
caller_ctx_trace = context.get_current()
with tracer.start_as_current_span('invocation'):
session = await self._get_or_create_session(
user_id=user_id,
Expand Down Expand Up @@ -600,9 +604,14 @@ async def _run_with_trace(
return

async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
caller_ctx_exec = context.get_current()
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx_exec)
try:
yield event
finally:
context.detach(token)

async with Aclosing(
self._exec_with_plugin(
Expand All @@ -613,7 +622,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
)
) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx_trace)
try:
yield event
finally:
context.detach(token)
# Run compaction after all events are yielded from the agent.
# (We don't compact in the middle of an invocation, we only compact at
# the end of an invocation.)
Expand All @@ -630,7 +643,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:

async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

async def rewind_async(
self,
Expand Down Expand Up @@ -1104,6 +1121,9 @@ async def run_live(
# AUDIO by default.
if run_config.response_modalities is None:
run_config.response_modalities = ['AUDIO']

caller_ctx = context.get_current()

if session is None and (user_id is None or session_id is None):
raise ValueError(
'Either session or user_id and session_id must be provided.'
Expand Down Expand Up @@ -1135,9 +1155,14 @@ async def run_live(
)

async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
caller_ctx_exec = context.get_current()
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx_exec)
try:
yield event
finally:
context.detach(token)

async with Aclosing(
self._exec_with_plugin(
Expand All @@ -1148,7 +1173,11 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
)
) as agen:
async for event in agen:
yield event
token = context.attach(caller_ctx)
try:
yield event
finally:
context.detach(token)

def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent
Expand Down
Loading