diff --git a/src/google/adk/cli/api_server.py b/src/google/adk/cli/api_server.py index 8980997d86..7595e55f18 100644 --- a/src/google/adk/cli/api_server.py +++ b/src/google/adk/cli/api_server.py @@ -164,6 +164,50 @@ def _get_scope_header( return None +import ipaddress as _ipaddress + +_LOOPBACK_HOSTNAMES = frozenset({"localhost"}) + + +def _is_loopback_address(host: str) -> bool: + """Return True if *host* (with or without a port) refers to a loopback address. + + Handles all four forms produced by browsers and uvicorn: + - Plain IPv4: "127.0.0.1" + - IPv4 with port: "127.0.0.1:8000" + - Bracketed IPv6: "[::1]" + - Bracketed IPv6+port: "[::1]:8000" + - Plain IPv6 (scope): "::1" (ASGI server tuple value) + - Hostname: "localhost" + - Hostname with port: "localhost:8000" + """ + bare = host + if bare.startswith("["): + # Bracketed IPv6: [addr] or [addr]:port + end = bare.find("]") + if end != -1: + bare = bare[1:end] + elif bare.count(":") == 1: + # IPv4:port or hostname:port (IPv6 without brackets has > 1 colon) + bare = bare.rsplit(":", 1)[0] + if bare in _LOOPBACK_HOSTNAMES: + return True + if bare.startswith("127."): + return True + try: + return _ipaddress.ip_address(bare).is_loopback + except ValueError: + return False + + +def _get_server_host(scope: dict[str, Any]) -> Optional[str]: + """Return the host the server is actually bound to (from ASGI server port).""" + server = scope.get("server") + if server and len(server) == 2: + return str(server[0]) + return None + + def _get_request_origin(scope: dict[str, Any]) -> Optional[str]: """Compute the effective origin for the current HTTP/WebSocket request.""" forwarded = _get_scope_header(scope, b"forwarded") @@ -200,12 +244,39 @@ def _is_request_origin_allowed( allowed_origin_regex: Optional[re.Pattern[str]], has_configured_allowed_origins: bool, ) -> bool: - """Validate an Origin header against explicit config or same-origin.""" + """Validate an Origin header against explicit config or same-origin. + + DNS-rebinding protection: when the server is bound to a loopback address + (127.0.0.1 / ::1 / localhost) and no explicit allow-origins have been + configured, we additionally require that the request's Origin header also + resolves to a loopback host. This prevents a DNS-rebinding attack where + an external page temporarily resolves to 127.0.0.1 and then POSTs to the + local development server by matching its own (evil.com) origin against the + Host header it controls. + """ if has_configured_allowed_origins and _is_origin_allowed( origin, allowed_literal_origins, allowed_origin_regex ): return True + # DNS-rebinding guard: if the server is on loopback and no explicit + # allow-origins list is configured, only permit origins whose host is also + # loopback. This mirrors the protection used by the MCP go-sdk SSEHandler. + server_host = _get_server_host(scope) + if ( + not has_configured_allowed_origins + and server_host is not None + and _is_loopback_address(server_host) + ): + try: + from urllib.parse import urlparse # noqa: PLC0415 (local import OK here) + + origin_host = urlparse(origin).hostname or "" + except Exception: # pylint: disable=broad-except + return False + if not _is_loopback_address(origin_host): + return False + request_origin = _get_request_origin(scope) if request_origin is None: return False diff --git a/tests/unittests/cli/test_dns_rebinding_protection.py b/tests/unittests/cli/test_dns_rebinding_protection.py new file mode 100644 index 0000000000..74d2ef0188 --- /dev/null +++ b/tests/unittests/cli/test_dns_rebinding_protection.py @@ -0,0 +1,170 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DNS-rebinding protection in _OriginCheckMiddleware.""" + +import pytest + +from google.adk.cli.api_server import ( + _is_loopback_address, + _is_request_origin_allowed, +) + + +class TestIsLoopbackAddress: + """Unit tests for _is_loopback_address.""" + + @pytest.mark.parametrize( + "host", + [ + "127.0.0.1", + "localhost", + "::1", + "[::1]", + "127.0.0.1:8000", + "localhost:8000", + "[::1]:8000", + "127.1.2.3", # any 127.x.x.x is loopback + ], + ) + def test_loopback_hosts(self, host: str): + assert _is_loopback_address(host), f"{host!r} should be loopback" + + @pytest.mark.parametrize( + "host", + [ + "evil.com", + "0.0.0.0", + "192.168.1.1", + "10.0.0.1", + "128.0.0.1", + "", + ], + ) + def test_non_loopback_hosts(self, host: str): + assert not _is_loopback_address(host), f"{host!r} should NOT be loopback" + + +class TestDnsRebindingProtection: + """Tests that DNS-rebinding attacks are blocked when server is on loopback.""" + + def _make_scope(self, server_host: str = "127.0.0.1", host_header: str = "127.0.0.1:8000") -> dict: + """Build a minimal ASGI scope for testing.""" + return { + "type": "http", + "method": "POST", + "server": (server_host, 8000), + "headers": [ + (b"host", host_header.encode()), + ], + "scheme": "http", + } + + # --- DNS rebinding scenarios (should be BLOCKED) --- + + def test_dns_rebinding_evil_origin_loopback_server_no_configured_origins(self): + """Attacker page (evil.com) DNS-rebinds to 127.0.0.1 and sends a POST. + + Browser sends Origin: http://evil.com, Host: evil.com. + Server is bound to 127.0.0.1. + No explicit allow-origins configured. + Expected: BLOCKED. + """ + scope = self._make_scope(server_host="127.0.0.1", host_header="evil.com:8000") + result = _is_request_origin_allowed( + origin="http://evil.com", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert not result, "DNS-rebinding from evil.com should be blocked on loopback server" + + def test_dns_rebinding_localhost_server(self): + """Same attack, server bound as 'localhost'.""" + scope = self._make_scope(server_host="localhost", host_header="evil.com") + result = _is_request_origin_allowed( + origin="http://evil.com", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert not result + + def test_dns_rebinding_ipv6_loopback_server(self): + """Same attack, server bound to ::1.""" + scope = self._make_scope(server_host="::1", host_header="evil.com") + result = _is_request_origin_allowed( + origin="http://evil.com", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert not result + + # --- Legitimate same-origin requests (should be ALLOWED) --- + + def test_same_origin_localhost_allowed(self): + """Legitimate browser request from localhost UI to localhost server.""" + scope = self._make_scope(server_host="127.0.0.1", host_header="127.0.0.1:8000") + result = _is_request_origin_allowed( + origin="http://127.0.0.1:8000", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert result, "Same-origin localhost request should be allowed" + + def test_same_origin_localhost_named(self): + """Browser opens http://localhost:8000 -> requests to localhost:8000.""" + scope = self._make_scope(server_host="127.0.0.1", host_header="localhost:8000") + result = _is_request_origin_allowed( + origin="http://localhost:8000", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert result + + # --- Explicit allow-origins configured (allow-list bypasses DNS guard) --- + + def test_explicit_allowlist_overrides_dns_rebinding_guard(self): + """If the developer explicitly allows evil.com, it should be permitted.""" + scope = self._make_scope(server_host="127.0.0.1", host_header="evil.com") + result = _is_request_origin_allowed( + origin="http://evil.com", + scope=scope, + allowed_literal_origins=["http://evil.com"], + allowed_origin_regex=None, + has_configured_allowed_origins=True, + ) + assert result, "Explicitly allowed origin should still pass" + + # --- Non-loopback server (protection does not apply) --- + + def test_non_loopback_server_no_dns_guard(self): + """Server bound to 0.0.0.0 — DNS guard must not interfere with same-origin check.""" + scope = self._make_scope(server_host="0.0.0.0", host_header="example.com:8000") + result = _is_request_origin_allowed( + origin="http://example.com:8000", + scope=scope, + allowed_literal_origins=[], + allowed_origin_regex=None, + has_configured_allowed_origins=False, + ) + assert result, "Same-origin on public server should be allowed"