diff --git a/src/config.rs b/src/config.rs index 508478b..3554308 100644 --- a/src/config.rs +++ b/src/config.rs @@ -38,12 +38,7 @@ impl std::str::FromStr for DnsUpstream { if s.starts_with("https://") { Ok(DnsUpstream::Https(s.to_string())) } else { - s.parse::().map(DnsUpstream::Udp).map_err(|e| { - format!( - "invalid DNS upstream '{}': expected socket address or https:// URL: {}", - s, e - ) - }) + resolve_dns_upstream_addr(s).map(DnsUpstream::Udp) } } } @@ -176,6 +171,28 @@ fn parse_upstream_addr(addr: &str) -> Result { .ok_or_else(|| "hostname resolved to no addresses".to_string()) } +fn resolve_dns_upstream_addr(input: &str) -> Result { + if let Ok(addr) = input.parse::() { + return Ok(addr); + } + + input + .to_socket_addrs() + .map_err(|e| { + format!( + "invalid DNS upstream '{}': expected host:port or https:// URL: {}", + input, e + ) + })? + .next() + .ok_or_else(|| { + format!( + "invalid DNS upstream '{}': hostname resolved to no addresses", + input + ) + }) +} + /// Comma-separated list of TCP ports for firewall redirection. /// /// Used with the `--ports` flag to restrict which ports are redirected. @@ -436,6 +453,18 @@ mod tests { } } + #[test] + fn test_dns_upstream_parse_udp_hostname() { + let upstream: DnsUpstream = "localhost:53".parse().unwrap(); + match upstream { + DnsUpstream::Udp(addr) => { + assert!(addr.ip().is_loopback()); + assert_eq!(addr.port(), 53); + } + _ => panic!("expected Udp variant"), + } + } + #[test] fn test_dns_upstream_parse_invalid() { let result: Result = "not-valid".parse();