Skip to content

Commit 1b06803

Browse files
committed
fix(http): add host check
1 parent ac749e3 commit 1b06803

3 files changed

Lines changed: 291 additions & 1 deletion

File tree

crates/rmcp/src/transport/common/http_header.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub const JSON_MIME_TYPE: &str = "application/json";
77
/// Reserved headers that must not be overridden by user-supplied custom headers.
88
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
99
/// injects it after initialization.
10+
#[allow(dead_code)]
1011
pub(crate) const RESERVED_HEADERS: &[&str] = &[
1112
"accept",
1213
HEADER_SESSION_ID,
@@ -36,6 +37,7 @@ pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), Stri
3637

3738
/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value.
3839
/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
40+
#[allow(dead_code)]
3941
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {
4042
let header_lowercase = header.to_ascii_lowercase();
4143
let scope_key = "scope=";

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
22

33
use bytes::Bytes;
44
use futures::{StreamExt, future::BoxFuture};
5-
use http::{Method, Request, Response, header::ALLOW};
5+
use http::{HeaderMap, Method, Request, Response, header::ALLOW};
66
use http_body::Body;
77
use http_body_util::{BodyExt, Full, combinators::BoxBody};
88
use tokio_stream::wrappers::ReceiverStream;
@@ -29,6 +29,7 @@ use crate::{
2929
},
3030
};
3131

32+
#[non_exhaustive]
3233
#[derive(Debug, Clone)]
3334
#[non_exhaustive]
3435
pub struct StreamableHttpServerConfig {
@@ -49,6 +50,16 @@ pub struct StreamableHttpServerConfig {
4950
/// When this token is cancelled, all active sessions are terminated and
5051
/// the server stops accepting new requests.
5152
pub cancellation_token: CancellationToken,
53+
/// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
54+
///
55+
/// By default, Streamable HTTP servers only accept loopback hosts to
56+
/// prevent DNS rebinding attacks against locally running servers. Public
57+
/// deployments should override this list with their own hostnames.
58+
/// examples:
59+
/// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
60+
/// or with ports:
61+
/// allowed_hosts = ["example.com", "example.com:8080"]
62+
pub allowed_hosts: Vec<String>,
5263
}
5364

5465
impl Default for StreamableHttpServerConfig {
@@ -59,10 +70,50 @@ impl Default for StreamableHttpServerConfig {
5970
stateful_mode: true,
6071
json_response: false,
6172
cancellation_token: CancellationToken::new(),
73+
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
6274
}
6375
}
6476
}
6577

78+
impl StreamableHttpServerConfig {
79+
pub fn with_allowed_hosts(
80+
mut self,
81+
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
82+
) -> Self {
83+
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
84+
self
85+
}
86+
/// Disable allowed hosts. This will allow requests with any `Host` or `Origin` header, which is NOT recommended for public deployments.
87+
pub fn disable_allowed_hosts(mut self) -> Self {
88+
self.allowed_hosts.clear();
89+
self
90+
}
91+
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
92+
self.sse_keep_alive = duration;
93+
self
94+
}
95+
96+
pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self {
97+
self.sse_retry = duration;
98+
self
99+
}
100+
101+
pub fn with_stateful_mode(mut self, stateful: bool) -> Self {
102+
self.stateful_mode = stateful;
103+
self
104+
}
105+
106+
pub fn with_json_response(mut self, json_response: bool) -> Self {
107+
self.json_response = json_response;
108+
self
109+
}
110+
111+
pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
112+
self.cancellation_token = token;
113+
self
114+
}
115+
}
116+
66117
impl StreamableHttpServerConfig {
67118
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
68119
self.sse_keep_alive = duration;
@@ -130,6 +181,87 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
130181
Ok(())
131182
}
132183

184+
fn forbidden_response(message: impl Into<String>) -> BoxResponse {
185+
Response::builder()
186+
.status(http::StatusCode::FORBIDDEN)
187+
.body(Full::new(Bytes::from(message.into())).boxed())
188+
.expect("valid response")
189+
}
190+
191+
fn normalize_host(host: &str) -> String {
192+
host.trim_matches('[')
193+
.trim_matches(']')
194+
.to_ascii_lowercase()
195+
}
196+
197+
#[derive(Debug, Clone, PartialEq, Eq)]
198+
struct NormalizedAuthority {
199+
host: String,
200+
port: Option<u16>,
201+
}
202+
203+
fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
204+
NormalizedAuthority {
205+
host: normalize_host(host),
206+
port,
207+
}
208+
}
209+
210+
fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
211+
let allowed = allowed.trim();
212+
if allowed.is_empty() {
213+
return None;
214+
}
215+
216+
if let Ok(authority) = http::uri::Authority::try_from(allowed) {
217+
return Some(normalize_authority(authority.host(), authority.port_u16()));
218+
}
219+
220+
Some(normalize_authority(allowed, None))
221+
}
222+
223+
fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
224+
if allowed_hosts.is_empty() {
225+
// If the allowed hosts list is empty, allow all hosts (not recommended).
226+
return true;
227+
}
228+
allowed_hosts
229+
.iter()
230+
.filter_map(|allowed| parse_allowed_authority(allowed))
231+
.any(|allowed| {
232+
allowed.host == host.host
233+
&& match allowed.port {
234+
Some(port) => host.port == Some(port),
235+
None => true,
236+
}
237+
})
238+
}
239+
240+
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
241+
let Some(host) = headers.get(http::header::HOST) else {
242+
return Err(forbidden_response("Forbidden:missing_host header"));
243+
};
244+
245+
let host = host
246+
.to_str()
247+
.map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?;
248+
let authority = http::uri::Authority::try_from(host)
249+
.map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?;
250+
Ok(normalize_authority(authority.host(), authority.port_u16()))
251+
}
252+
253+
fn validate_dns_rebinding_headers(
254+
headers: &HeaderMap,
255+
config: &StreamableHttpServerConfig,
256+
) -> Result<(), BoxResponse> {
257+
let host = parse_host_header(headers)?;
258+
if !host_is_allowed(&host, &config.allowed_hosts) {
259+
return Err(forbidden_response("Forbidden: Host header is not allowed"));
260+
}
261+
262+
Ok(())
263+
}
264+
133265
/// # Streamable HTTP server
134266
///
135267
/// An HTTP service that implements the
@@ -279,6 +411,9 @@ where
279411
B: Body + Send + 'static,
280412
B::Error: Display,
281413
{
414+
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
415+
return response;
416+
}
282417
let method = request.method().clone();
283418
let allowed_methods = match self.config.stateful_mode {
284419
true => "GET, POST, DELETE",

crates/rmcp/tests/test_custom_headers.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,3 +870,156 @@ fn test_protocol_version_utilities() {
870870
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26));
871871
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18));
872872
}
873+
874+
/// Integration test: Verify server validates only the Host header for DNS rebinding protection
875+
#[tokio::test]
876+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
877+
async fn test_server_validates_host_header_for_dns_rebinding_protection() {
878+
use std::sync::Arc;
879+
880+
use bytes::Bytes;
881+
use http::{Method, Request, header::CONTENT_TYPE};
882+
use http_body_util::Full;
883+
use rmcp::{
884+
handler::server::ServerHandler,
885+
model::{ServerCapabilities, ServerInfo},
886+
transport::streamable_http_server::{
887+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
888+
},
889+
};
890+
use serde_json::json;
891+
892+
#[derive(Clone)]
893+
struct TestHandler;
894+
895+
impl ServerHandler for TestHandler {
896+
fn get_info(&self) -> ServerInfo {
897+
ServerInfo::new(ServerCapabilities::builder().build())
898+
}
899+
}
900+
901+
let service = StreamableHttpService::new(
902+
|| Ok(TestHandler),
903+
Arc::new(LocalSessionManager::default()),
904+
StreamableHttpServerConfig::default(),
905+
);
906+
907+
let init_body = json!({
908+
"jsonrpc": "2.0",
909+
"id": 1,
910+
"method": "initialize",
911+
"params": {
912+
"protocolVersion": "2025-03-26",
913+
"capabilities": {},
914+
"clientInfo": {
915+
"name": "test-client",
916+
"version": "1.0.0"
917+
}
918+
}
919+
});
920+
921+
let allowed_request = Request::builder()
922+
.method(Method::POST)
923+
.header("Accept", "application/json, text/event-stream")
924+
.header(CONTENT_TYPE, "application/json")
925+
.header("Host", "localhost:8080")
926+
.header("Origin", "http://localhost:8080")
927+
.body(Full::new(Bytes::from(init_body.to_string())))
928+
.unwrap();
929+
930+
let response = service.handle(allowed_request).await;
931+
assert_eq!(response.status(), http::StatusCode::OK);
932+
933+
let bad_host_request = Request::builder()
934+
.method(Method::POST)
935+
.header("Accept", "application/json, text/event-stream")
936+
.header(CONTENT_TYPE, "application/json")
937+
.header("Host", "attacker.example")
938+
.body(Full::new(Bytes::from(init_body.to_string())))
939+
.unwrap();
940+
941+
let response = service.handle(bad_host_request).await;
942+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
943+
944+
let ignored_origin_request = Request::builder()
945+
.method(Method::POST)
946+
.header("Accept", "application/json, text/event-stream")
947+
.header(CONTENT_TYPE, "application/json")
948+
.header("Host", "localhost:8080")
949+
.header("Origin", "http://attacker.example")
950+
.body(Full::new(Bytes::from(init_body.to_string())))
951+
.unwrap();
952+
953+
let response = service.handle(ignored_origin_request).await;
954+
assert_eq!(response.status(), http::StatusCode::OK);
955+
}
956+
957+
/// Integration test: Verify server can enforce an allowed Host port when configured
958+
#[tokio::test]
959+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
960+
async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
961+
use std::sync::Arc;
962+
963+
use bytes::Bytes;
964+
use http::{Method, Request, header::CONTENT_TYPE};
965+
use http_body_util::Full;
966+
use rmcp::{
967+
handler::server::ServerHandler,
968+
model::{ServerCapabilities, ServerInfo},
969+
transport::streamable_http_server::{
970+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
971+
},
972+
};
973+
use serde_json::json;
974+
975+
#[derive(Clone)]
976+
struct TestHandler;
977+
978+
impl ServerHandler for TestHandler {
979+
fn get_info(&self) -> ServerInfo {
980+
ServerInfo::new(ServerCapabilities::builder().build())
981+
}
982+
}
983+
984+
let service = StreamableHttpService::new(
985+
|| Ok(TestHandler),
986+
Arc::new(LocalSessionManager::default()),
987+
StreamableHttpServerConfig::default().with_allowed_hosts(["localhost:8080"]),
988+
);
989+
990+
let init_body = json!({
991+
"jsonrpc": "2.0",
992+
"id": 1,
993+
"method": "initialize",
994+
"params": {
995+
"protocolVersion": "2025-03-26",
996+
"capabilities": {},
997+
"clientInfo": {
998+
"name": "test-client",
999+
"version": "1.0.0"
1000+
}
1001+
}
1002+
});
1003+
1004+
let allowed_request = Request::builder()
1005+
.method(Method::POST)
1006+
.header("Accept", "application/json, text/event-stream")
1007+
.header(CONTENT_TYPE, "application/json")
1008+
.header("Host", "localhost:8080")
1009+
.body(Full::new(Bytes::from(init_body.to_string())))
1010+
.unwrap();
1011+
1012+
let response = service.handle(allowed_request).await;
1013+
assert_eq!(response.status(), http::StatusCode::OK);
1014+
1015+
let wrong_port_request = Request::builder()
1016+
.method(Method::POST)
1017+
.header("Accept", "application/json, text/event-stream")
1018+
.header(CONTENT_TYPE, "application/json")
1019+
.header("Host", "localhost:3000")
1020+
.body(Full::new(Bytes::from(init_body.to_string())))
1021+
.unwrap();
1022+
1023+
let response = service.handle(wrong_port_request).await;
1024+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1025+
}

0 commit comments

Comments
 (0)