Skip to content

Commit 5578772

Browse files
xuanyang15copybara-github
authored andcommitted
fix(auth): isolate resolved credentials in context to prevent race conditions and data leakage
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 901957852
1 parent 7623ff1 commit 5578772

10 files changed

Lines changed: 117 additions & 116 deletions

File tree

src/google/adk/agents/invocation_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..apps.app import EventsCompactionConfig
2929
from ..apps.app import ResumabilityConfig
3030
from ..artifacts.base_artifact_service import BaseArtifactService
31+
from ..auth.auth_credential import AuthCredential
3132
from ..auth.credential_service.base_credential_service import BaseCredentialService
3233
from ..events.event import Event
3334
from ..memory.base_memory_service import BaseMemoryService
@@ -214,6 +215,9 @@ class InvocationContext(BaseModel):
214215
canonical_tools_cache: Optional[list[BaseTool]] = None
215216
"""The cache of canonical tools for this invocation."""
216217

218+
credential_by_key: dict[str, AuthCredential] = Field(default_factory=dict)
219+
"""The resolved credentials for this invocation, keyed by credential_key."""
220+
217221
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
218222
default_factory=_InvocationCostManager
219223
)

src/google/adk/agents/readonly_context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
if TYPE_CHECKING:
2323
from google.genai import types
2424

25+
from ..auth.auth_credential import AuthCredential
2526
from ..sessions.session import Session
2627
from .invocation_context import InvocationContext
2728
from .run_config import RunConfig
@@ -69,3 +70,7 @@ def user_id(self) -> str:
6970
def run_config(self) -> Optional[RunConfig]:
7071
"""The run config of the current invocation. READONLY field."""
7172
return self._invocation_context.run_config
73+
74+
def get_credential(self, key: str) -> Optional[AuthCredential]:
75+
"""Gets a resolved credential by key for this invocation."""
76+
return self._invocation_context.credential_by_key.get(key)

src/google/adk/evaluation/eval_metrics.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,6 @@ class BaseCriterion(BaseModel):
115115
description="The threshold to be used by the metric.",
116116
)
117117

118-
include_intermediate_responses_in_final: bool = Field(
119-
default=False,
120-
description=(
121-
"Whether to evaluate the full agent response including intermediate"
122-
" natural language text (e.g. text emitted before tool calls) in"
123-
" addition to the final response. By default, only the final"
124-
" response text is sent to the judge. When True, text from all"
125-
" intermediate invocation events is concatenated with the final"
126-
" response before evaluation. This is useful for agents that emit"
127-
" text both before and after tool calls within a single invocation."
128-
),
129-
)
130-
131118

132119
class LlmAsAJudgeCriterion(BaseCriterion):
133120
"""Criterion when using LLM-As-A-Judge metric."""

src/google/adk/evaluation/llm_as_judge_utils.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from .common import EvalBaseModel
2727
from .eval_case import get_all_tool_calls_with_responses
2828
from .eval_case import IntermediateDataType
29-
from .eval_case import Invocation
30-
from .eval_case import InvocationEvents
3129
from .eval_metrics import RubricScore
3230
from .evaluator import EvalStatus
3331

@@ -46,38 +44,8 @@ class Label(enum.Enum):
4644

4745

4846
def get_text_from_content(
49-
content: Optional[Union[genai_types.Content, Invocation]],
50-
*,
51-
include_intermediate_responses_in_final: bool = False,
47+
content: Optional[genai_types.Content],
5248
) -> Optional[str]:
53-
"""Extracts text from a `Content` or an `Invocation`.
54-
55-
When `content` is a `Content`, returns the concatenated text of its parts.
56-
57-
When `content` is an `Invocation`, returns the text of the invocation's final
58-
response. If `include_intermediate_responses_in_final` is True, text from
59-
intermediate invocation events (e.g. natural language emitted before tool
60-
calls) is concatenated with the final response text.
61-
"""
62-
if isinstance(content, Invocation):
63-
if not include_intermediate_responses_in_final:
64-
# Flag off: revert to basic plain-Content behavior.
65-
return get_text_from_content(content.final_response)
66-
67-
parts: list[str] = []
68-
if isinstance(content.intermediate_data, InvocationEvents):
69-
# Walk intermediate events in order; collect text parts.
70-
for event in content.intermediate_data.invocation_events:
71-
text = get_text_from_content(event.content)
72-
if text:
73-
parts.append(text)
74-
# Then fetch the final response text and append it to the end.
75-
final_text = get_text_from_content(content.final_response)
76-
if final_text:
77-
parts.append(final_text)
78-
79-
return "\n".join(parts) if parts else None
80-
8149
if content and content.parts:
8250
return "\n".join([p.text for p in content.parts if p.text])
8351

src/google/adk/evaluation/rubric_based_final_response_quality_v1.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -274,18 +274,7 @@ def format_auto_rater_prompt(
274274
"""Returns the autorater prompt."""
275275
self.create_effective_rubrics_list(actual_invocation.rubrics)
276276
user_input = get_text_from_content(actual_invocation.user_content)
277-
278-
criterion = self._eval_metric.criterion
279-
include_intermediate = getattr(
280-
criterion, "include_intermediate_responses_in_final", False
281-
)
282-
final_response = (
283-
get_text_from_content(
284-
actual_invocation,
285-
include_intermediate_responses_in_final=include_intermediate,
286-
)
287-
or ""
288-
)
277+
final_response = get_text_from_content(actual_invocation.final_response)
289278

290279
rubrics_text = "\n".join([
291280
f"* {r.rubric_content.text_property}"

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,11 @@ async def _resolve_toolset_auth(
143143
if not auth_config:
144144
continue
145145

146+
auth_config_copy = auth_config.model_copy(deep=True)
146147
try:
147-
credential = await CredentialManager(auth_config).get_auth_credential(
148-
callback_context
149-
)
148+
credential = await CredentialManager(
149+
auth_config_copy
150+
).get_auth_credential(callback_context)
150151
except ValueError as e:
151152
# Validation errors from CredentialManager should be logged but not
152153
# block the flow - the toolset may still work without auth
@@ -158,14 +159,16 @@ async def _resolve_toolset_auth(
158159
credential = None
159160

160161
if credential:
161-
# Populate in-place for toolset to use in get_tools()
162-
auth_config.exchanged_auth_credential = credential
162+
# Store in invocation context to avoid data leakage and race conditions
163+
invocation_context.credential_by_key[auth_config.credential_key] = (
164+
credential
165+
)
163166
else:
164167
# Need auth - will interrupt
165168
toolset_id = (
166169
f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}'
167170
)
168-
pending_auth_requests[toolset_id] = auth_config
171+
pending_auth_requests[toolset_id] = auth_config_copy
169172

170173
if not pending_auth_requests:
171174
return

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,31 @@ def __init__(
203203
)
204204
self._use_mcp_resources = use_mcp_resources
205205

206-
def _get_auth_headers(self) -> Optional[Dict[str, str]]:
206+
def _get_auth_headers(
207+
self, readonly_context: Optional[ReadonlyContext] = None
208+
) -> Optional[Dict[str, str]]:
207209
"""Build authentication headers from exchanged credential.
208210
211+
Args:
212+
readonly_context: Readonly context to get credentials from.
213+
209214
Returns:
210215
Dictionary of auth headers, or None if no auth configured.
211216
"""
212-
if not self._auth_config or not self._auth_config.exchanged_auth_credential:
217+
if not self._auth_config:
213218
return None
214219

215-
credential = self._auth_config.exchanged_auth_credential
220+
credential = None
221+
if readonly_context:
222+
credential = readonly_context.get_credential(
223+
self._auth_config.credential_key
224+
)
225+
226+
if not credential:
227+
credential = self._auth_config.exchanged_auth_credential
228+
229+
if not credential:
230+
return None
216231
headers: Optional[Dict[str, str]] = None
217232

218233
if credential.oauth2:
@@ -289,7 +304,7 @@ async def _execute_with_session(
289304
headers.update(provider_headers)
290305

291306
# Add auth headers from exchanged credential if available
292-
auth_headers = self._get_auth_headers()
307+
auth_headers = self._get_auth_headers(readonly_context)
293308
if auth_headers:
294309
headers.update(auth_headers)
295310

tests/unittests/auth/test_toolset_auth.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def mock_invocation_context(self):
110110
ctx.credential_service = None
111111
ctx.app_name = "test-app"
112112
ctx.user_id = "test-user"
113+
ctx.credential_by_key = {}
113114
return ctx
114115

115116
@pytest.fixture
@@ -154,10 +155,10 @@ async def test_toolset_without_auth_config_skipped(
154155
assert mock_invocation_context.end_invocation is False
155156

156157
@pytest.mark.asyncio
157-
async def test_toolset_with_credential_available_populates_config(
158+
async def test_toolset_with_credential_available_populates_context(
158159
self, mock_invocation_context, mock_agent
159160
):
160-
"""Test that credential is populated in auth_config when available."""
161+
"""Test that credential is stored in invocation context when available."""
161162
auth_config = create_oauth2_auth_config()
162163
toolset = MockToolset(auth_config=auth_config)
163164
mock_agent.tools = [toolset]
@@ -184,8 +185,52 @@ async def test_toolset_with_credential_available_populates_config(
184185
# No auth request events - credential was available
185186
assert len(events) == 0
186187
assert mock_invocation_context.end_invocation is False
187-
# Credential should be populated in auth_config
188-
assert auth_config.exchanged_auth_credential == mock_credential
188+
# Credential should be stored in invocation context, not auth_config
189+
assert (
190+
mock_invocation_context.credential_by_key[auth_config.credential_key]
191+
== mock_credential
192+
)
193+
assert auth_config.exchanged_auth_credential is None
194+
195+
@pytest.mark.asyncio
196+
async def test_toolset_auth_uses_copy_and_does_not_mutate_shared_config(
197+
self, mock_invocation_context, mock_agent
198+
):
199+
"""Test that _resolve_toolset_auth uses a copy and does not mutate shared config."""
200+
auth_config = create_oauth2_auth_config()
201+
toolset = MockToolset(auth_config=auth_config)
202+
mock_agent.tools = [toolset]
203+
204+
def create_mock_cm(cfg):
205+
m = AsyncMock()
206+
m._auth_config = cfg
207+
208+
async def get_cred(ctx):
209+
cfg.exchanged_auth_credential = AuthCredential(
210+
auth_type=AuthCredentialTypes.OAUTH2,
211+
oauth2=OAuth2Auth(auth_uri="https://example.com/consent"),
212+
)
213+
return None
214+
215+
m.get_auth_credential = AsyncMock(side_effect=get_cred)
216+
return m
217+
218+
with patch(
219+
"google.adk.flows.llm_flows.base_llm_flow.CredentialManager",
220+
side_effect=create_mock_cm,
221+
):
222+
events = []
223+
async for event in _resolve_toolset_auth(
224+
mock_invocation_context, mock_agent
225+
):
226+
events.append(event)
227+
228+
# Should yield one auth request event
229+
assert len(events) == 1
230+
assert mock_invocation_context.end_invocation is True
231+
232+
# The shared auth_config should NOT be mutated
233+
assert auth_config.exchanged_auth_credential is None
189234

190235
@pytest.mark.asyncio
191236
async def test_toolset_without_credential_yields_auth_event(

tests/unittests/evaluation/test_llm_as_judge_utils.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from google.adk.evaluation.app_details import AgentDetails
2020
from google.adk.evaluation.app_details import AppDetails
2121
from google.adk.evaluation.eval_case import IntermediateData
22-
from google.adk.evaluation.eval_case import Invocation
2322
from google.adk.evaluation.eval_case import InvocationEvent
2423
from google.adk.evaluation.eval_case import InvocationEvents
2524
from google.adk.evaluation.eval_rubrics import RubricScore
@@ -89,49 +88,6 @@ def test_get_text_from_content_with_mixed_parts():
8988
assert get_text_from_content(content) == "Hello\nWorld"
9089

9190

92-
def test_get_text_from_content_with_invocation_include_intermediate_responses_in_final():
93-
"""Tests get_text_from_content on an Invocation with and without the flag."""
94-
intermediate_text = "Let me check."
95-
final_response_text = "Done."
96-
invocation = Invocation(
97-
user_content=genai_types.Content(parts=[genai_types.Part(text="user")]),
98-
intermediate_data=InvocationEvents(
99-
invocation_events=[
100-
InvocationEvent(
101-
author="agent",
102-
content=genai_types.Content(
103-
parts=[genai_types.Part(text=intermediate_text)]
104-
),
105-
),
106-
InvocationEvent(
107-
author="tool",
108-
content=genai_types.Content(
109-
parts=[
110-
genai_types.Part(
111-
function_call=genai_types.FunctionCall(name="t")
112-
)
113-
]
114-
),
115-
),
116-
]
117-
),
118-
final_response=genai_types.Content(
119-
parts=[genai_types.Part(text=final_response_text)]
120-
),
121-
)
122-
123-
# Flag off (default): only the final response text is returned.
124-
assert get_text_from_content(invocation) == final_response_text
125-
126-
# Flag on: intermediate text is concatenated before the final response.
127-
assert (
128-
get_text_from_content(
129-
invocation, include_intermediate_responses_in_final=True
130-
)
131-
== f"{intermediate_text}\n{final_response_text}"
132-
)
133-
134-
13591
def test_get_eval_status_with_none_score():
13692
"""Tests get_eval_status returns NOT_EVALUATED for a None score."""
13793
assert get_eval_status(score=None, threshold=0.5) == EvalStatus.NOT_EVALUATED

tests/unittests/tools/mcp_tool/test_mcp_toolset_auth.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,32 @@ def test_get_auth_headers_api_key_non_header_logs_warning(self, caplog):
267267

268268
# Should return None for non-header API key
269269
assert headers is None
270+
271+
def test_get_auth_headers_reads_from_readonly_context(
272+
self, toolset_with_oauth2
273+
):
274+
"""Test that _get_auth_headers reads from ReadonlyContext first."""
275+
from google.adk.agents.readonly_context import ReadonlyContext
276+
277+
mock_readonly_context = Mock(spec=ReadonlyContext)
278+
mock_credential = AuthCredential(
279+
auth_type=AuthCredentialTypes.OAUTH2,
280+
oauth2=OAuth2Auth(access_token="token-from-context"),
281+
)
282+
mock_readonly_context.get_credential.return_value = mock_credential
283+
284+
# Even if exchanged_auth_credential has a different value
285+
toolset_with_oauth2._auth_config.exchanged_auth_credential = AuthCredential(
286+
auth_type=AuthCredentialTypes.OAUTH2,
287+
oauth2=OAuth2Auth(access_token="token-from-config"),
288+
)
289+
290+
headers = toolset_with_oauth2._get_auth_headers(
291+
readonly_context=mock_readonly_context
292+
)
293+
294+
assert headers is not None
295+
assert headers["Authorization"] == "Bearer token-from-context"
296+
mock_readonly_context.get_credential.assert_called_once_with(
297+
toolset_with_oauth2._auth_config.credential_key
298+
)

0 commit comments

Comments
 (0)