Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/google/adk/cli/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,50 @@ def _get_scope_header(
return None


import ipaddress as _ipaddress

_LOOPBACK_HOSTNAMES = frozenset({"localhost"})


def _is_loopback_address(host: str) -> bool:
"""Return True if *host* (with or without a port) refers to a loopback address.

Handles all four forms produced by browsers and uvicorn:
- Plain IPv4: "127.0.0.1"
- IPv4 with port: "127.0.0.1:8000"
- Bracketed IPv6: "[::1]"
- Bracketed IPv6+port: "[::1]:8000"
- Plain IPv6 (scope): "::1" (ASGI server tuple value)
- Hostname: "localhost"
- Hostname with port: "localhost:8000"
"""
bare = host
if bare.startswith("["):
# Bracketed IPv6: [addr] or [addr]:port
end = bare.find("]")
if end != -1:
bare = bare[1:end]
elif bare.count(":") == 1:
# IPv4:port or hostname:port (IPv6 without brackets has > 1 colon)
bare = bare.rsplit(":", 1)[0]
if bare in _LOOPBACK_HOSTNAMES:
return True
if bare.startswith("127."):
return True
try:
return _ipaddress.ip_address(bare).is_loopback
except ValueError:
return False


def _get_server_host(scope: dict[str, Any]) -> Optional[str]:
"""Return the host the server is actually bound to (from ASGI server port)."""
server = scope.get("server")
if server and len(server) == 2:
return str(server[0])
return None


def _get_request_origin(scope: dict[str, Any]) -> Optional[str]:
"""Compute the effective origin for the current HTTP/WebSocket request."""
forwarded = _get_scope_header(scope, b"forwarded")
Expand Down Expand Up @@ -200,12 +244,39 @@ def _is_request_origin_allowed(
allowed_origin_regex: Optional[re.Pattern[str]],
has_configured_allowed_origins: bool,
) -> bool:
"""Validate an Origin header against explicit config or same-origin."""
"""Validate an Origin header against explicit config or same-origin.

DNS-rebinding protection: when the server is bound to a loopback address
(127.0.0.1 / ::1 / localhost) and no explicit allow-origins have been
configured, we additionally require that the request's Origin header also
resolves to a loopback host. This prevents a DNS-rebinding attack where
an external page temporarily resolves to 127.0.0.1 and then POSTs to the
local development server by matching its own (evil.com) origin against the
Host header it controls.
"""
if has_configured_allowed_origins and _is_origin_allowed(
origin, allowed_literal_origins, allowed_origin_regex
):
return True

# DNS-rebinding guard: if the server is on loopback and no explicit
# allow-origins list is configured, only permit origins whose host is also
# loopback. This mirrors the protection used by the MCP go-sdk SSEHandler.
server_host = _get_server_host(scope)
if (
not has_configured_allowed_origins
and server_host is not None
and _is_loopback_address(server_host)
):
try:
from urllib.parse import urlparse # noqa: PLC0415 (local import OK here)

origin_host = urlparse(origin).hostname or ""
except Exception: # pylint: disable=broad-except
return False
if not _is_loopback_address(origin_host):
return False

request_origin = _get_request_origin(scope)
if request_origin is None:
return False
Expand Down
170 changes: 170 additions & 0 deletions tests/unittests/cli/test_dns_rebinding_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for DNS-rebinding protection in _OriginCheckMiddleware."""

import pytest

from google.adk.cli.api_server import (
_is_loopback_address,
_is_request_origin_allowed,
)


class TestIsLoopbackAddress:
"""Unit tests for _is_loopback_address."""

@pytest.mark.parametrize(
"host",
[
"127.0.0.1",
"localhost",
"::1",
"[::1]",
"127.0.0.1:8000",
"localhost:8000",
"[::1]:8000",
"127.1.2.3", # any 127.x.x.x is loopback
],
)
def test_loopback_hosts(self, host: str):
assert _is_loopback_address(host), f"{host!r} should be loopback"

@pytest.mark.parametrize(
"host",
[
"evil.com",
"0.0.0.0",
"192.168.1.1",
"10.0.0.1",
"128.0.0.1",
"",
],
)
def test_non_loopback_hosts(self, host: str):
assert not _is_loopback_address(host), f"{host!r} should NOT be loopback"


class TestDnsRebindingProtection:
"""Tests that DNS-rebinding attacks are blocked when server is on loopback."""

def _make_scope(self, server_host: str = "127.0.0.1", host_header: str = "127.0.0.1:8000") -> dict:
"""Build a minimal ASGI scope for testing."""
return {
"type": "http",
"method": "POST",
"server": (server_host, 8000),
"headers": [
(b"host", host_header.encode()),
],
"scheme": "http",
}

# --- DNS rebinding scenarios (should be BLOCKED) ---

def test_dns_rebinding_evil_origin_loopback_server_no_configured_origins(self):
"""Attacker page (evil.com) DNS-rebinds to 127.0.0.1 and sends a POST.

Browser sends Origin: http://evil.com, Host: evil.com.
Server is bound to 127.0.0.1.
No explicit allow-origins configured.
Expected: BLOCKED.
"""
scope = self._make_scope(server_host="127.0.0.1", host_header="evil.com:8000")
result = _is_request_origin_allowed(
origin="http://evil.com",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert not result, "DNS-rebinding from evil.com should be blocked on loopback server"

def test_dns_rebinding_localhost_server(self):
"""Same attack, server bound as 'localhost'."""
scope = self._make_scope(server_host="localhost", host_header="evil.com")
result = _is_request_origin_allowed(
origin="http://evil.com",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert not result

def test_dns_rebinding_ipv6_loopback_server(self):
"""Same attack, server bound to ::1."""
scope = self._make_scope(server_host="::1", host_header="evil.com")
result = _is_request_origin_allowed(
origin="http://evil.com",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert not result

# --- Legitimate same-origin requests (should be ALLOWED) ---

def test_same_origin_localhost_allowed(self):
"""Legitimate browser request from localhost UI to localhost server."""
scope = self._make_scope(server_host="127.0.0.1", host_header="127.0.0.1:8000")
result = _is_request_origin_allowed(
origin="http://127.0.0.1:8000",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert result, "Same-origin localhost request should be allowed"

def test_same_origin_localhost_named(self):
"""Browser opens http://localhost:8000 -> requests to localhost:8000."""
scope = self._make_scope(server_host="127.0.0.1", host_header="localhost:8000")
result = _is_request_origin_allowed(
origin="http://localhost:8000",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert result

# --- Explicit allow-origins configured (allow-list bypasses DNS guard) ---

def test_explicit_allowlist_overrides_dns_rebinding_guard(self):
"""If the developer explicitly allows evil.com, it should be permitted."""
scope = self._make_scope(server_host="127.0.0.1", host_header="evil.com")
result = _is_request_origin_allowed(
origin="http://evil.com",
scope=scope,
allowed_literal_origins=["http://evil.com"],
allowed_origin_regex=None,
has_configured_allowed_origins=True,
)
assert result, "Explicitly allowed origin should still pass"

# --- Non-loopback server (protection does not apply) ---

def test_non_loopback_server_no_dns_guard(self):
"""Server bound to 0.0.0.0 — DNS guard must not interfere with same-origin check."""
scope = self._make_scope(server_host="0.0.0.0", host_header="example.com:8000")
result = _is_request_origin_allowed(
origin="http://example.com:8000",
scope=scope,
allowed_literal_origins=[],
allowed_origin_regex=None,
has_configured_allowed_origins=False,
)
assert result, "Same-origin on public server should be allowed"