Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ struct Config {
#[cfg(feature = "cookies")]
cookie_store: Option<Arc<dyn cookie::CookieStore>>,
trust_dns: bool,
#[cfg(feature = "trust-dns")]
ip_filter: fn(std::net::IpAddr) -> bool,
error: Option<crate::Error>,
https_only: bool,
#[cfg(feature = "http3")]
Expand Down Expand Up @@ -219,6 +221,8 @@ impl ClientBuilder {
local_address: None,
nodelay: true,
trust_dns: cfg!(feature = "trust-dns"),
#[cfg(feature = "trust-dns")]
ip_filter: |_| true,
#[cfg(feature = "cookies")]
cookie_store: None,
https_only: false,
Expand Down Expand Up @@ -270,7 +274,7 @@ impl ClientBuilder {
let mut resolver: Arc<dyn Resolve> = match config.trust_dns {
false => Arc::new(GaiResolver::new()),
#[cfg(feature = "trust-dns")]
true => Arc::new(TrustDnsResolver::default()),
true => Arc::new(TrustDnsResolver::new(config.ip_filter)),
#[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
};
Expand Down Expand Up @@ -689,6 +693,7 @@ impl ClientBuilder {
proxies,
proxies_maybe_http_auth,
https_only: config.https_only,
ip_filter: config.ip_filter,
}),
})
}
Expand Down Expand Up @@ -1543,6 +1548,17 @@ impl ClientBuilder {
}
}

/// Adds a filter for valid IP addresses during DNS lookup.
///
/// # Optional
///
/// This requires the optional `trust-dns` feature to be enabled.
#[cfg(feature = "trust-dns")]
pub fn ip_filter(mut self, filter: fn(std::net::IpAddr) -> bool) -> ClientBuilder {
self.config.ip_filter = filter;
self
}

/// Restrict the Client to be used with HTTPS only requests.
///
/// Defaults to false.
Expand Down Expand Up @@ -1797,6 +1813,11 @@ impl Client {
}
}

if let Err(err) = validate_url(self.inner.ip_filter, &url) {
return Pending {
inner: PendingInner::Error(Some(err)),
};
}
let uri = expect_uri(&url);

let (reusable, body) = match body {
Expand Down Expand Up @@ -2052,6 +2073,7 @@ struct ClientRef {
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
https_only: bool,
ip_filter: fn(IpAddr) -> bool,
}

impl ClientRef {
Expand Down Expand Up @@ -2165,6 +2187,8 @@ impl PendingRequest {
}
self.retry_count += 1;

// XXX: We can't return an `Err` here, as we are mutating the `in_flight` future to restart it.
// However, at this point, we already validated `self.url` so it should be good.
let uri = expect_uri(&self.url);

*self.as_mut().in_flight().get_mut() = match *self.as_mut().in_flight().as_ref() {
Expand Down Expand Up @@ -2379,6 +2403,11 @@ impl Future for PendingRequest {
std::mem::replace(self.as_mut().headers(), HeaderMap::new());

remove_sensitive_headers(&mut headers, &self.url, &self.urls);

if let Err(err) = validate_url(self.client.ip_filter, &self.url) {
return Poll::Ready(Err(err));
}

let uri = expect_uri(&self.url);
let body = match self.body {
Some(Some(ref body)) => Body::reusable(body.clone()),
Expand Down Expand Up @@ -2476,6 +2505,20 @@ fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieS
}
}

fn validate_url(ip_filter: fn(IpAddr) -> bool, url: &Url) -> Result<(), crate::Error> {
let is_valid_ip = match url.host() {
Some(url::Host::Ipv4(ip)) => (ip_filter)(IpAddr::V4(ip)),
Some(url::Host::Ipv6(ip)) => (ip_filter)(IpAddr::V6(ip)),
_ => true,
};

if !is_valid_ip {
let e = trust_dns_resolver::error::ResolveError::from("destination is restricted");
return Err(crate::Error::new(crate::error::Kind::Request, Some(e)));
}
Ok(())
}

#[cfg(test)]
mod tests {
#[tokio::test]
Expand Down
27 changes: 25 additions & 2 deletions src/dns/trust_dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,45 @@ use std::sync::Arc;
use super::{Addrs, Resolve, Resolving};

/// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub(crate) struct TrustDnsResolver {
/// Since we might not have been called in the context of a
/// Tokio Runtime in initialization, so we must delay the actual
/// construction of the resolver.
state: Arc<OnceCell<TokioAsyncResolver>>,
filter: fn(std::net::IpAddr) -> bool,
}

struct SocketAddrs {
iter: LookupIpIntoIter,
filter: fn(std::net::IpAddr) -> bool,
}

impl TrustDnsResolver {
pub fn new(filter: fn(std::net::IpAddr) -> bool) -> Self {
TrustDnsResolver {
state: Default::default(),
filter,
}
}
}

impl Resolve for TrustDnsResolver {
fn resolve(&self, name: Name) -> Resolving {
let resolver = self.clone();
Box::pin(async move {
let filter = resolver.filter;
let resolver = resolver.state.get_or_try_init(new_resolver)?;

let lookup = resolver.lookup_ip(name.as_str()).await?;
if !lookup.iter().any(filter) {
let e = trust_dns_resolver::error::ResolveError::from("destination is restricted");
return Err(e.into());
}

let addrs: Addrs = Box::new(SocketAddrs {
iter: lookup.into_iter(),
filter,
});
Ok(addrs)
})
Expand All @@ -43,7 +61,12 @@ impl Iterator for SocketAddrs {
type Item = SocketAddr;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|ip_addr| SocketAddr::new(ip_addr, 0))
loop {
let ip_addr = self.iter.next()?;
if (self.filter)(ip_addr) {
return Some(SocketAddr::new(ip_addr, 0));
}
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ impl Error {
if hyper_err.is_connect() {
return true;
}
} else if err.downcast_ref::<trust_dns_resolver::error::ResolveError>().is_some() {
return true;
}

source = err.source();
Expand Down