Skip to content

Commit 30b904e

Browse files
kaligautierxuanyang15
authored andcommitted
fix(tools): support regional Discovery Engine endpoints
Merge #4720 Close #4697 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#4720 from kaligautier:fix/global-and-regional-region ed0b18d PiperOrigin-RevId: 888728543
1 parent b3fcd8a commit 30b904e

2 files changed

Lines changed: 269 additions & 7 deletions

File tree

src/google/adk/tools/discovery_engine_search_tool.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,75 @@
3636
r'search_result_mode.*DOCUMENTS', re.IGNORECASE
3737
)
3838

39+
_DEFAULT_ENDPOINT = 'discoveryengine.googleapis.com'
40+
_GLOBAL_LOCATION = 'global'
41+
_LOCATION_PATTERN = re.compile(
42+
r'/locations/([a-z0-9-]+)(?:/|$)', flags=re.IGNORECASE
43+
)
44+
_VALID_LOCATION_PATTERN = re.compile(r'^[a-z0-9-]+$')
45+
46+
47+
def _normalize_location(location: str, location_type: str) -> str:
48+
"""Normalizes and validates a location value."""
49+
normalized_location = location.strip().lower()
50+
if not normalized_location:
51+
raise ValueError(f'{location_type} must not be empty if specified.')
52+
if not _VALID_LOCATION_PATTERN.fullmatch(normalized_location):
53+
raise ValueError(
54+
f'{location_type} must contain only letters, digits, and hyphens.'
55+
)
56+
return normalized_location
57+
58+
59+
def _extract_resource_location(resource_id: str) -> Optional[str]:
60+
"""Extracts and validates location from a resource id."""
61+
if '/locations/' not in resource_id.lower():
62+
return None
63+
64+
location_match = _LOCATION_PATTERN.search(resource_id)
65+
if not location_match:
66+
raise ValueError('Invalid location in data_store_id or search_engine_id.')
67+
return _normalize_location(location_match.group(1), 'resource location')
68+
69+
70+
def _resolve_location(resource_id: str, location: Optional[str]) -> str:
71+
"""Resolves the Discovery Engine location to use for the endpoint."""
72+
inferred_location = _extract_resource_location(resource_id)
73+
74+
if location is not None:
75+
normalized_location = _normalize_location(location, 'location')
76+
if inferred_location and normalized_location != inferred_location:
77+
raise ValueError(
78+
'location must match the location in data_store_id or '
79+
'search_engine_id.'
80+
)
81+
return normalized_location
82+
83+
if inferred_location:
84+
return inferred_location
85+
return _GLOBAL_LOCATION
86+
87+
88+
def _build_client_options(
89+
resource_id: str,
90+
quota_project_id: Optional[str],
91+
location: Optional[str],
92+
) -> Optional[client_options.ClientOptions]:
93+
"""Builds client options for Discovery Engine requests."""
94+
client_options_kwargs = {}
95+
resolved_location = _resolve_location(resource_id, location)
96+
97+
if resolved_location != _GLOBAL_LOCATION:
98+
client_options_kwargs['api_endpoint'] = (
99+
f'{resolved_location}-{_DEFAULT_ENDPOINT}'
100+
)
101+
if quota_project_id:
102+
client_options_kwargs['quota_project_id'] = quota_project_id
103+
104+
if not client_options_kwargs:
105+
return None
106+
return client_options.ClientOptions(**client_options_kwargs)
107+
39108

40109
class SearchResultMode(enum.Enum):
41110
"""Search result mode for discovery engine search."""
@@ -61,6 +130,7 @@ def __init__(
61130
max_results: Optional[int] = None,
62131
*,
63132
search_result_mode: Optional[SearchResultMode] = None,
133+
location: Optional[str] = None,
64134
):
65135
"""Initializes the DiscoveryEngineSearchTool.
66136
@@ -75,9 +145,12 @@ def __init__(
75145
filter: The filter to be applied to the search request. Default is None.
76146
max_results: The maximum number of results to return. Default is None.
77147
search_result_mode: The search result mode. When None (default),
78-
automatically detects the correct mode by trying CHUNKS first
79-
and falling back to DOCUMENTS if the datastore requires it.
80-
Set explicitly to CHUNKS or DOCUMENTS to skip auto-detection.
148+
automatically detects the correct mode by trying CHUNKS first and
149+
falling back to DOCUMENTS if the datastore requires it. Set explicitly
150+
to CHUNKS or DOCUMENTS to skip auto-detection.
151+
location: Optional endpoint location override.
152+
Examples: "global", "us", "eu". If not specified, location is inferred
153+
from `data_store_id` or `search_engine_id` and defaults to "global".
81154
"""
82155
super().__init__(self.discovery_engine_search)
83156
if (data_store_id is None and search_engine_id is None) or (
@@ -99,13 +172,15 @@ def __init__(
99172
self._filter = filter
100173
self._max_results = max_results
101174
self._search_result_mode = search_result_mode
175+
self._location = location
102176

103177
credentials, _ = google.auth.default()
104178
quota_project_id = getattr(credentials, 'quota_project_id', None)
105-
options = (
106-
client_options.ClientOptions(quota_project_id=quota_project_id)
107-
if quota_project_id
108-
else None
179+
resource_id = data_store_id or search_engine_id or ''
180+
options = _build_client_options(
181+
resource_id=resource_id,
182+
quota_project_id=quota_project_id,
183+
location=location,
109184
)
110185
self._discovery_engine_client = discoveryengine.SearchServiceClient(
111186
credentials=credentials, client_options=options

tests/unittests/tools/test_discovery_engine_search_tool.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,193 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
8080
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
8181
)
8282

83+
@pytest.mark.parametrize(
84+
("tool_kwargs", "expected_endpoint"),
85+
[
86+
(
87+
{
88+
"data_store_id": (
89+
"projects/test/locations/eu/collections/default_collection/"
90+
"dataStores/test_data_store"
91+
)
92+
},
93+
"eu-discoveryengine.googleapis.com",
94+
),
95+
(
96+
{
97+
"search_engine_id": (
98+
"projects/test/locations/us/collections/default_collection/"
99+
"engines/test_search_engine"
100+
)
101+
},
102+
"us-discoveryengine.googleapis.com",
103+
),
104+
(
105+
{
106+
"data_store_id": (
107+
"projects/test/locations/europe-west1/collections/"
108+
"default_collection/dataStores/test_data_store"
109+
)
110+
},
111+
"europe-west1-discoveryengine.googleapis.com",
112+
),
113+
],
114+
)
115+
@mock.patch.object(discovery_engine_search_tool, "client_options")
116+
@mock.patch.object(discoveryengine, "SearchServiceClient")
117+
def test_init_with_regional_location_uses_regional_endpoint(
118+
self,
119+
mock_search_client,
120+
mock_client_options,
121+
tool_kwargs,
122+
expected_endpoint,
123+
):
124+
"""Test initialization uses the expected regional API endpoint."""
125+
DiscoveryEngineSearchTool(**tool_kwargs)
126+
127+
mock_client_options.ClientOptions.assert_called_once_with(
128+
api_endpoint=expected_endpoint
129+
)
130+
mock_search_client.assert_called_once_with(
131+
credentials="credentials",
132+
client_options=mock_client_options.ClientOptions.return_value,
133+
)
134+
135+
@mock.patch.object(discovery_engine_search_tool, "client_options")
136+
@mock.patch.object(discoveryengine, "SearchServiceClient")
137+
def test_init_with_explicit_location_override_uses_input_location(
138+
self, mock_search_client, mock_client_options
139+
):
140+
"""Test initialization uses explicit location when resource has none."""
141+
DiscoveryEngineSearchTool(
142+
data_store_id="test_data_store",
143+
location="eu",
144+
)
145+
146+
mock_client_options.ClientOptions.assert_called_once_with(
147+
api_endpoint="eu-discoveryengine.googleapis.com"
148+
)
149+
mock_search_client.assert_called_once_with(
150+
credentials="credentials",
151+
client_options=mock_client_options.ClientOptions.return_value,
152+
)
153+
154+
@mock.patch.object(discoveryengine, "SearchServiceClient")
155+
def test_init_with_mismatched_location_raises_error(self, mock_search_client):
156+
"""Test initialization rejects mismatched location overrides."""
157+
with pytest.raises(
158+
ValueError,
159+
match=(
160+
"location must match the location in data_store_id or "
161+
"search_engine_id."
162+
),
163+
):
164+
DiscoveryEngineSearchTool(
165+
data_store_id=(
166+
"projects/test/locations/us/collections/default_collection/"
167+
"dataStores/test_data_store"
168+
),
169+
location="eu",
170+
)
171+
172+
mock_search_client.assert_not_called()
173+
174+
@mock.patch.object(discoveryengine, "SearchServiceClient")
175+
def test_init_with_empty_location_raises_error(self, mock_search_client):
176+
"""Test initialization rejects an empty location override."""
177+
with pytest.raises(
178+
ValueError, match="location must not be empty if specified."
179+
):
180+
DiscoveryEngineSearchTool(
181+
data_store_id=(
182+
"projects/test/locations/us/collections/default_collection/"
183+
"dataStores/test_data_store"
184+
),
185+
location=" ",
186+
)
187+
188+
mock_search_client.assert_not_called()
189+
190+
@mock.patch.object(discoveryengine, "SearchServiceClient")
191+
def test_init_with_invalid_override_location_raises_error(
192+
self, mock_search_client
193+
):
194+
"""Test initialization rejects invalid override location characters."""
195+
with pytest.raises(
196+
ValueError,
197+
match="location must contain only letters, digits, and hyphens.",
198+
):
199+
DiscoveryEngineSearchTool(
200+
data_store_id="test_data_store",
201+
location="attacker.com#",
202+
)
203+
204+
mock_search_client.assert_not_called()
205+
206+
@mock.patch.object(discoveryengine, "SearchServiceClient")
207+
def test_init_with_invalid_resource_location_raises_error(
208+
self, mock_search_client
209+
):
210+
"""Test initialization rejects invalid resource location characters."""
211+
with pytest.raises(
212+
ValueError,
213+
match="Invalid location in data_store_id or search_engine_id.",
214+
):
215+
DiscoveryEngineSearchTool(
216+
data_store_id=(
217+
"projects/test/locations/attacker.com#/collections/"
218+
"default_collection/dataStores/test_data_store"
219+
)
220+
)
221+
222+
mock_search_client.assert_not_called()
223+
224+
@mock.patch.object(discovery_engine_search_tool, "client_options")
225+
@mock.patch.object(discoveryengine, "SearchServiceClient")
226+
def test_init_with_global_location_keeps_default_endpoint(
227+
self, mock_search_client, mock_client_options
228+
):
229+
"""Test initialization keeps default API endpoint for global location."""
230+
DiscoveryEngineSearchTool(
231+
data_store_id=(
232+
"projects/test/locations/global/collections/default_collection/"
233+
"dataStores/test_data_store"
234+
)
235+
)
236+
237+
mock_client_options.ClientOptions.assert_not_called()
238+
mock_search_client.assert_called_once_with(
239+
credentials="credentials", client_options=None
240+
)
241+
242+
@mock.patch.object(discovery_engine_search_tool, "client_options")
243+
@mock.patch.object(discoveryengine, "SearchServiceClient")
244+
def test_init_with_regional_location_and_quota_project_id(
245+
self, mock_search_client, mock_client_options
246+
):
247+
"""Test initialization uses endpoint and quota project id together."""
248+
mock_credentials = mock.MagicMock()
249+
mock_credentials.quota_project_id = "test-quota-project"
250+
251+
with mock.patch.object(
252+
auth, "default", return_value=(mock_credentials, "project")
253+
):
254+
DiscoveryEngineSearchTool(
255+
data_store_id=(
256+
"projects/test/locations/eu/collections/default_collection/"
257+
"dataStores/test_data_store"
258+
)
259+
)
260+
261+
mock_client_options.ClientOptions.assert_called_once_with(
262+
api_endpoint="eu-discoveryengine.googleapis.com",
263+
quota_project_id="test-quota-project",
264+
)
265+
mock_search_client.assert_called_once_with(
266+
credentials=mock_credentials,
267+
client_options=mock_client_options.ClientOptions.return_value,
268+
)
269+
83270
@mock.patch.object(discovery_engine_search_tool, "client_options")
84271
@mock.patch.object(
85272
discoveryengine,

0 commit comments

Comments
 (0)