@@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
22
33use bytes:: Bytes ;
44use futures:: { StreamExt , future:: BoxFuture } ;
5- use http:: { Method , Request , Response , header:: ALLOW } ;
5+ use http:: { HeaderMap , Method , Request , Response , header:: ALLOW } ;
66use http_body:: Body ;
77use http_body_util:: { BodyExt , Full , combinators:: BoxBody } ;
88use tokio_stream:: wrappers:: ReceiverStream ;
@@ -29,6 +29,7 @@ use crate::{
2929 } ,
3030} ;
3131
32+ #[ non_exhaustive]
3233#[ derive( Debug , Clone ) ]
3334#[ non_exhaustive]
3435pub struct StreamableHttpServerConfig {
@@ -49,6 +50,16 @@ pub struct StreamableHttpServerConfig {
4950 /// When this token is cancelled, all active sessions are terminated and
5051 /// the server stops accepting new requests.
5152 pub cancellation_token : CancellationToken ,
53+ /// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
54+ ///
55+ /// By default, Streamable HTTP servers only accept loopback hosts to
56+ /// prevent DNS rebinding attacks against locally running servers. Public
57+ /// deployments should override this list with their own hostnames.
58+ /// examples:
59+ /// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
60+ /// or with ports:
61+ /// allowed_hosts = ["example.com", "example.com:8080"]
62+ pub allowed_hosts : Vec < String > ,
5263}
5364
5465impl Default for StreamableHttpServerConfig {
@@ -59,10 +70,50 @@ impl Default for StreamableHttpServerConfig {
5970 stateful_mode : true ,
6071 json_response : false ,
6172 cancellation_token : CancellationToken :: new ( ) ,
73+ allowed_hosts : vec ! [ "localhost" . into( ) , "127.0.0.1" . into( ) , "::1" . into( ) ] ,
6274 }
6375 }
6476}
6577
78+ impl StreamableHttpServerConfig {
79+ pub fn with_allowed_hosts (
80+ mut self ,
81+ allowed_hosts : impl IntoIterator < Item = impl Into < String > > ,
82+ ) -> Self {
83+ self . allowed_hosts = allowed_hosts. into_iter ( ) . map ( Into :: into) . collect ( ) ;
84+ self
85+ }
86+ /// Disable allowed hosts. This will allow requests with any `Host` or `Origin` header, which is NOT recommended for public deployments.
87+ pub fn disable_allowed_hosts ( mut self ) -> Self {
88+ self . allowed_hosts . clear ( ) ;
89+ self
90+ }
91+ pub fn with_sse_keep_alive ( mut self , duration : Option < Duration > ) -> Self {
92+ self . sse_keep_alive = duration;
93+ self
94+ }
95+
96+ pub fn with_sse_retry ( mut self , duration : Option < Duration > ) -> Self {
97+ self . sse_retry = duration;
98+ self
99+ }
100+
101+ pub fn with_stateful_mode ( mut self , stateful : bool ) -> Self {
102+ self . stateful_mode = stateful;
103+ self
104+ }
105+
106+ pub fn with_json_response ( mut self , json_response : bool ) -> Self {
107+ self . json_response = json_response;
108+ self
109+ }
110+
111+ pub fn with_cancellation_token ( mut self , token : CancellationToken ) -> Self {
112+ self . cancellation_token = token;
113+ self
114+ }
115+ }
116+
66117impl StreamableHttpServerConfig {
67118 pub fn with_sse_keep_alive ( mut self , duration : Option < Duration > ) -> Self {
68119 self . sse_keep_alive = duration;
@@ -130,6 +181,87 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
130181 Ok ( ( ) )
131182}
132183
184+ fn forbidden_response ( message : impl Into < String > ) -> BoxResponse {
185+ Response :: builder ( )
186+ . status ( http:: StatusCode :: FORBIDDEN )
187+ . body ( Full :: new ( Bytes :: from ( message. into ( ) ) ) . boxed ( ) )
188+ . expect ( "valid response" )
189+ }
190+
191+ fn normalize_host ( host : & str ) -> String {
192+ host. trim_matches ( '[' )
193+ . trim_matches ( ']' )
194+ . to_ascii_lowercase ( )
195+ }
196+
197+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
198+ struct NormalizedAuthority {
199+ host : String ,
200+ port : Option < u16 > ,
201+ }
202+
203+ fn normalize_authority ( host : & str , port : Option < u16 > ) -> NormalizedAuthority {
204+ NormalizedAuthority {
205+ host : normalize_host ( host) ,
206+ port,
207+ }
208+ }
209+
210+ fn parse_allowed_authority ( allowed : & str ) -> Option < NormalizedAuthority > {
211+ let allowed = allowed. trim ( ) ;
212+ if allowed. is_empty ( ) {
213+ return None ;
214+ }
215+
216+ if let Ok ( authority) = http:: uri:: Authority :: try_from ( allowed) {
217+ return Some ( normalize_authority ( authority. host ( ) , authority. port_u16 ( ) ) ) ;
218+ }
219+
220+ Some ( normalize_authority ( allowed, None ) )
221+ }
222+
223+ fn host_is_allowed ( host : & NormalizedAuthority , allowed_hosts : & [ String ] ) -> bool {
224+ if allowed_hosts. is_empty ( ) {
225+ // If the allowed hosts list is empty, allow all hosts (not recommended).
226+ return true ;
227+ }
228+ allowed_hosts
229+ . iter ( )
230+ . filter_map ( |allowed| parse_allowed_authority ( allowed) )
231+ . any ( |allowed| {
232+ allowed. host == host. host
233+ && match allowed. port {
234+ Some ( port) => host. port == Some ( port) ,
235+ None => true ,
236+ }
237+ } )
238+ }
239+
240+ fn parse_host_header ( headers : & HeaderMap ) -> Result < NormalizedAuthority , BoxResponse > {
241+ let Some ( host) = headers. get ( http:: header:: HOST ) else {
242+ return Err ( forbidden_response ( "Forbidden:missing_host header" ) ) ;
243+ } ;
244+
245+ let host = host
246+ . to_str ( )
247+ . map_err ( |_| forbidden_response ( "Forbidden: Invalid Host header encoding" ) ) ?;
248+ let authority = http:: uri:: Authority :: try_from ( host)
249+ . map_err ( |_| forbidden_response ( "Forbidden: Invalid Host header" ) ) ?;
250+ Ok ( normalize_authority ( authority. host ( ) , authority. port_u16 ( ) ) )
251+ }
252+
253+ fn validate_dns_rebinding_headers (
254+ headers : & HeaderMap ,
255+ config : & StreamableHttpServerConfig ,
256+ ) -> Result < ( ) , BoxResponse > {
257+ let host = parse_host_header ( headers) ?;
258+ if !host_is_allowed ( & host, & config. allowed_hosts ) {
259+ return Err ( forbidden_response ( "Forbidden: Host header is not allowed" ) ) ;
260+ }
261+
262+ Ok ( ( ) )
263+ }
264+
133265/// # Streamable HTTP server
134266///
135267/// An HTTP service that implements the
@@ -279,6 +411,9 @@ where
279411 B : Body + Send + ' static ,
280412 B :: Error : Display ,
281413 {
414+ if let Err ( response) = validate_dns_rebinding_headers ( request. headers ( ) , & self . config ) {
415+ return response;
416+ }
282417 let method = request. method ( ) . clone ( ) ;
283418 let allowed_methods = match self . config . stateful_mode {
284419 true => "GET, POST, DELETE" ,
0 commit comments