Skip to content

Commit 9bf5e4e

Browse files
committed
Merge remote-tracking branch 'origin/main' into fix/final-response-match-full-response
2 parents cc22cca + 8de1ae8 commit 9bf5e4e

4 files changed

Lines changed: 265 additions & 15 deletions

File tree

src/google/adk/telemetry/google_cloud.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414

1515
from __future__ import annotations
1616

17+
import enum
1718
import logging
1819
import os
20+
from typing import Any
21+
from typing import Callable
1922
from typing import cast
2023
from typing import Optional
2124
from typing import TYPE_CHECKING
2225

2326
import google.auth
27+
from google.auth.transport import mtls
2428
from opentelemetry.sdk._logs import LogRecordProcessor
2529
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
2630
from opentelemetry.sdk.metrics.export import MetricReader
@@ -40,6 +44,19 @@
4044
_GCP_LOG_NAME_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_DEFAULT_LOG_NAME'
4145
_DEFAULT_LOG_NAME = 'adk-otel'
4246

47+
_DEFAULT_TELEMETRY_TRACES_ENPOINT = 'https://telemetry.googleapis.com/v1/traces'
48+
_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT = (
49+
'https://telemetry.mtls.googleapis.com/v1/traces'
50+
)
51+
52+
53+
class _MtlsEndpoint(enum.Enum):
54+
"""The mTLS endpoint setting."""
55+
56+
AUTO = 'auto'
57+
ALWAYS = 'always'
58+
NEVER = 'never'
59+
4360

4461
def get_gcp_exporters(
4562
enable_cloud_tracing: bool = False,
@@ -100,10 +117,24 @@ def _get_gcp_span_exporter(credentials: Credentials) -> SpanProcessor:
100117
from google.auth.transport.requests import AuthorizedSession
101118
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
102119

120+
session = AuthorizedSession(credentials=credentials)
121+
122+
use_client_cert = _use_client_cert_effective()
123+
if use_client_cert:
124+
client_cert_source = (
125+
mtls.default_client_cert_source()
126+
if mtls.has_default_client_cert_source()
127+
else None
128+
)
129+
session.configure_mtls_channel()
130+
endpoint = _get_api_endpoint(client_cert_source)
131+
else:
132+
endpoint = _DEFAULT_TELEMETRY_TRACES_ENPOINT
133+
103134
return BatchSpanProcessor(
104135
OTLPSpanExporter(
105-
session=AuthorizedSession(credentials=credentials),
106-
endpoint='https://telemetry.googleapis.com/v1/traces',
136+
session=session,
137+
endpoint=endpoint,
107138
)
108139
)
109140

@@ -158,3 +189,62 @@ def get_gcp_resource(project_id: Optional[str] = None) -> Resource:
158189
' GCE, GKE or CloudRun related resource attributes may be missing'
159190
)
160191
return resource
192+
193+
194+
def _get_api_endpoint(
    client_cert_source: Callable[[], tuple[bytes, bytes]] | None = None,
) -> str:
  """Picks the traces endpoint from the mTLS setting and cert availability.

  The setting is read from the `GOOGLE_API_USE_MTLS_ENDPOINT` environment
  variable (case-insensitive). An unset variable means `auto`; an
  unrecognized value logs a warning and is also treated as `auto`.

  Args:
    client_cert_source: A callable that returns the client certificate and
      key, or None when no client certificate is available.

  Returns:
    str: The mTLS traces endpoint when the setting is `always`, or when it is
    `auto` and `client_cert_source` is truthy; otherwise the regular traces
    endpoint.
  """
  raw_setting = os.getenv(
      'GOOGLE_API_USE_MTLS_ENDPOINT', _MtlsEndpoint.AUTO.value
  ).lower()

  try:
    mode = _MtlsEndpoint(raw_setting)
  except ValueError:
    logger.warning(
        'Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of '
        '%s. Defaulting to %s.',
        [member.value for member in _MtlsEndpoint],
        _MtlsEndpoint.AUTO.value,
    )
    mode = _MtlsEndpoint.AUTO

  # Explicit settings win outright; only `auto` looks at the certificate.
  if mode is _MtlsEndpoint.ALWAYS:
    return _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT
  if mode is _MtlsEndpoint.NEVER:
    return _DEFAULT_TELEMETRY_TRACES_ENPOINT
  if client_cert_source:
    return _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT
  return _DEFAULT_TELEMETRY_TRACES_ENPOINT
227+
228+
229+
def _use_client_cert_effective() -> bool:
  """Decides whether a client certificate should be used for mTLS.

  Prefers `mtls.should_use_client_cert()`. If that call raises ImportError or
  AttributeError (e.g. the installed google-auth does not provide it), the
  decision falls back to the `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment
  variable, compared case-insensitively against `true`.

  Returns:
    bool: whether client certificate should be used for mTLS. In the fallback
    path any value other than `true` yields False; values other than `true`
    or `false` additionally log a warning.
  """
  try:
    return bool(mtls.should_use_client_cert())
  except (ImportError, AttributeError):
    # google-auth helper unavailable: fall back to the environment variable.
    env_setting = os.getenv(
        'GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false'
    ).lower()

  if env_setting not in ('true', 'false'):
    logger.warning(
        'Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be'
        ' either `true` or `false`'
    )
  return env_setting == 'true'

src/google/adk/tools/skill_toolset.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import annotations
2020

2121
import asyncio
22+
import collections
2223
import json
2324
import logging
2425
import mimetypes
@@ -915,10 +916,11 @@ def __init__(
915916
self._script_timeout = script_timeout
916917
# Needed for mid-turn reloading of skill tools.
917918
self._use_invocation_cache = False
918-
self._invocation_cache: dict[
919+
# Cache fetched remote skill definitions per turn to reduce requests to registry
920+
self._fetched_skill_cache: collections.OrderedDict[
919921
str,
920922
dict[str, models.Skill | asyncio.Future[models.Skill | None] | None],
921-
] = {}
923+
] = collections.OrderedDict()
922924
self._max_cache_turns = 16
923925

924926
self._provided_tools_by_name = {}
@@ -1023,14 +1025,13 @@ async def _get_or_fetch_skill(
10231025
return None
10241026

10251027
if invocation_id:
1026-
if invocation_id not in self._invocation_cache:
1028+
if invocation_id not in self._fetched_skill_cache:
10271029
# Enforce bounded cache (FIFO eviction)
1028-
if len(self._invocation_cache) >= self._max_cache_turns:
1029-
oldest = next(iter(self._invocation_cache))
1030-
self._invocation_cache.pop(oldest)
1031-
self._invocation_cache[invocation_id] = {}
1030+
if len(self._fetched_skill_cache) >= self._max_cache_turns:
1031+
self._fetched_skill_cache.popitem(last=False)
1032+
self._fetched_skill_cache[invocation_id] = {}
10321033

1033-
turn_cache = self._invocation_cache[invocation_id]
1034+
turn_cache = self._fetched_skill_cache[invocation_id]
10341035
if skill_name in turn_cache:
10351036
cached = turn_cache[skill_name]
10361037
if isinstance(cached, asyncio.Future):
@@ -1080,6 +1081,16 @@ async def process_llm_request(
10801081

10811082
llm_request.append_instructions(instructions)
10821083

1084+
@override
async def close(self) -> None:
  """Cancels unfinished skill fetches, empties the cache, closes the parent.

  Every cached entry that is an `asyncio.Future` and is not yet done gets
  cancelled. The whole fetched-skill cache is then cleared before delegating
  to the parent class's `close()`.
  """
  unfinished_fetches = [
      entry
      for per_turn in self._fetched_skill_cache.values()
      for entry in per_turn.values()
      if isinstance(entry, asyncio.Future) and not entry.done()
  ]
  for fetch in unfinished_fetches:
    fetch.cancel()

  self._fetched_skill_cache.clear()
  await super().close()
1093+
10831094

10841095
def __getattr__(name: str) -> Any:
10851096
if name == "DEFAULT_SKILL_SYSTEM_INSTRUCTION":

tests/unittests/telemetry/test_google_cloud.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,18 @@
1616
from typing import Optional
1717
from unittest import mock
1818

19+
from google.adk.telemetry import google_cloud
20+
from google.adk.telemetry.google_cloud import _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT
21+
from google.adk.telemetry.google_cloud import _DEFAULT_TELEMETRY_TRACES_ENPOINT
22+
from google.adk.telemetry.google_cloud import _get_api_endpoint
23+
from google.adk.telemetry.google_cloud import _get_gcp_span_exporter
24+
from google.adk.telemetry.google_cloud import _use_client_cert_effective
1925
from google.adk.telemetry.google_cloud import get_gcp_exporters
2026
from google.adk.telemetry.google_cloud import get_gcp_resource
27+
import google.auth.credentials
28+
from google.auth.transport import mtls
29+
from google.auth.transport import requests
30+
from opentelemetry.exporter.otlp.proto.http import trace_exporter
2131
import pytest
2232

2333

@@ -89,3 +99,108 @@ def test_get_gcp_resource(
8999
otel_resource.attributes.get("gcp.project_id", None)
90100
== expected_project_id
91101
)
102+
103+
104+
@mock.patch.object(mtls, "should_use_client_cert", autospec=True)
def test_use_client_cert_effective_from_mtls(mock_should_use):
  """The helper mirrors what `mtls.should_use_client_cert` reports."""
  for reported in (True, False):
    mock_should_use.return_value = reported
    assert bool(_use_client_cert_effective()) is reported
111+
112+
113+
def test_use_client_cert_effective_from_env(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
  """Falls back to the env variable when the google-auth helper is missing."""
  env_name = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
  # Simulate a google-auth without `should_use_client_cert`.
  helper_unavailable = mock.patch.object(
      mtls,
      "should_use_client_cert",
      autospec=True,
      side_effect=AttributeError,
  )

  with helper_unavailable:
    monkeypatch.setenv(env_name, "true")
    assert _use_client_cert_effective()

    monkeypatch.setenv(env_name, "false")
    assert not _use_client_cert_effective()

    # An unrecognized value is treated as False and reported via a warning.
    monkeypatch.setenv(env_name, "maybe")
    assert not _use_client_cert_effective()
    expected_warning = (
        "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be"
        " either `true` or `false`"
    )
    assert expected_warning in caplog.text
136+
137+
138+
@pytest.mark.parametrize(
    "env_val, cert_source, expected",
    [
        ("auto", lambda: b"cert", _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT),
        ("auto", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT),
        ("always", None, _DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT),
        ("never", lambda: b"cert", _DEFAULT_TELEMETRY_TRACES_ENPOINT),
        ("invalid", None, _DEFAULT_TELEMETRY_TRACES_ENPOINT),
    ],
)
def test_get_api_endpoint(
    env_val,
    cert_source,
    expected,
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
):
  """Endpoint follows `GOOGLE_API_USE_MTLS_ENDPOINT` and cert availability."""
  monkeypatch.setenv("GOOGLE_API_USE_MTLS_ENDPOINT", env_val)

  assert _get_api_endpoint(cert_source) == expected

  # Only the unrecognized setting is additionally checked for its warning.
  if env_val == "invalid":
    assert (
        "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be one of"
        in caplog.text
    )
164+
165+
166+
@mock.patch.object(requests, "AuthorizedSession", autospec=True)
@mock.patch(
    "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter",
    autospec=True,
)
@mock.patch(
    "google.adk.telemetry.google_cloud.BatchSpanProcessor", autospec=True
)
@mock.patch(
    "google.adk.telemetry.google_cloud._use_client_cert_effective",
    autospec=True,
)
@mock.patch(
    "google.auth.transport.mtls.has_default_client_cert_source", autospec=True
)
@mock.patch(
    "google.auth.transport.mtls.default_client_cert_source", autospec=True
)
def test_get_gcp_span_exporter_mtls(
    mock_default_cert: mock.MagicMock,
    mock_has_cert: mock.MagicMock,
    mock_use_cert: mock.MagicMock,
    mock_batch: mock.MagicMock,
    mock_exporter: mock.MagicMock,
    mock_session: mock.MagicMock,
    monkeypatch: pytest.MonkeyPatch,
):
  """With client certs enabled, the exporter uses an mTLS session/endpoint.

  Mock arguments are injected bottom-up by the `mock.patch` decorators; the
  `monkeypatch` fixture must therefore come after all of them.
  """
  # The endpoint choice reads GOOGLE_API_USE_MTLS_ENDPOINT. Clear it so an
  # ambient value (e.g. `never`) cannot change the endpoint asserted below.
  monkeypatch.delenv("GOOGLE_API_USE_MTLS_ENDPOINT", raising=False)

  credentials = mock.create_autospec(
      google.auth.credentials.Credentials, instance=True
  )
  mock_use_cert.return_value = True
  mock_has_cert.return_value = True
  mock_default_cert.return_value = b"cert"

  _get_gcp_span_exporter(credentials)

  mock_session.assert_called_once_with(credentials=credentials)
  mock_session.return_value.configure_mtls_channel.assert_called_once()
  mock_exporter.assert_called_once_with(
      session=mock_session.return_value,
      endpoint=_DEFAULT_MTLS_TELEMETRY_TRACES_ENPOINT,
  )

tests/unittests/tools/test_skill_toolset.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import collections
1617
import logging
1718
import sys
1819
from unittest import mock
@@ -1859,14 +1860,14 @@ async def test_turn_scoped_skill_cache_eviction(mock_registry, mock_skill1):
18591860
for i in range(16):
18601861
await toolset._get_or_fetch_skill("skill1", f"turn-{i}")
18611862

1862-
assert len(toolset._invocation_cache) == 16
1863-
assert "turn-0" in toolset._invocation_cache
1863+
assert len(toolset._fetched_skill_cache) == 16
1864+
assert "turn-0" in toolset._fetched_skill_cache
18641865

18651866
# Next turn should evict oldest (turn-0)
18661867
await toolset._get_or_fetch_skill("skill1", "turn-16")
1867-
assert len(toolset._invocation_cache) == 16
1868-
assert "turn-0" not in toolset._invocation_cache
1869-
assert "turn-1" in toolset._invocation_cache
1868+
assert len(toolset._fetched_skill_cache) == 16
1869+
assert "turn-0" not in toolset._fetched_skill_cache
1870+
assert "turn-1" in toolset._fetched_skill_cache
18701871

18711872

18721873
@pytest.mark.asyncio
@@ -1891,3 +1892,36 @@ async def delayed_get_skill(name):
18911892

18921893
# Registry should have been called exactly once
18931894
mock_registry.get_skill.assert_called_once_with(name="skill1")
1895+
1896+
1897+
def test_skill_toolset_disables_invocation_cache():
  """A fresh SkillToolset has `_use_invocation_cache` set to False."""
  # pylint: disable=protected-access
  assert skill_toolset.SkillToolset()._use_invocation_cache is False
1901+
1902+
1903+
@pytest.mark.asyncio
async def test_close_cancels_futures_and_clears_cache():
  """close() cancels pending futures, leaves done ones, and empties the cache."""
  # pylint: disable=protected-access
  toolset = skill_toolset.SkillToolset()

  # One fetch still in flight and one that has already completed.
  running_loop = asyncio.get_running_loop()
  pending_fetch = running_loop.create_future()
  finished_fetch = running_loop.create_future()
  finished_fetch.set_result(None)

  turn_entries = {"skill1": pending_fetch, "skill2": finished_fetch}
  toolset._fetched_skill_cache = collections.OrderedDict(turn1=turn_entries)

  await toolset.close()

  assert pending_fetch.cancelled()
  # A future that is already done is not cancelled by close().
  assert not finished_fetch.cancelled()
  assert not toolset._fetched_skill_cache

0 commit comments

Comments
 (0)