Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 47 additions & 84 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ unicase = "^2.9"
base64 = { version = "^0.22", optional = true }
zeroize = { version = "^1.8", features = ["zeroize_derive"], optional = true }
native-tls = { version = "^0.2", optional = true }
rustls = { version = "^0.23", optional = true }
rustls = { version = "^0.23", optional = true, default-features = false, features = ["logging", "std"] }
rustls-pemfile = { version = "^2.2", optional = true }
rustls-pki-types = { version = "^1.14", features = ["alloc"], optional = true }
tracing = { version = "0.1", optional = true }
webpki = { version = "^0.22", optional = true }
webpki-roots = { version = "^1.0", optional = true }

Expand All @@ -32,3 +33,4 @@ rust-tls = [
"auth",
]
auth = ["base64", "zeroize"]
tracing = ["dep:tracing"]
48 changes: 46 additions & 2 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::{
};
#[cfg(feature = "auth")]
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
#[cfg(feature = "tracing")]
use tracing::Span;

const CR_LF: &str = "\r\n";
const DEFAULT_REDIRECT_LIMIT: usize = 5;
Expand Down Expand Up @@ -499,7 +501,7 @@ impl<'a> RequestMessage<'a> {
/// assert_eq!(response.status_code(), StatusCode::new(200));
/// ```
///
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug)]
pub struct Request<'a> {
message: RequestMessage<'a>,
redirect_policy: RedirectPolicy<fn(&str) -> bool>,
Expand All @@ -508,6 +510,8 @@ pub struct Request<'a> {
write_timeout: Option<Duration>,
timeout: Duration,
root_cert_file_pem: Option<&'a Path>,
#[cfg(feature = "rust-tls")]
rustls_config: Option<std::sync::Arc<rustls::ClientConfig>>,
}

impl<'a> Request<'a> {
Expand All @@ -534,6 +538,8 @@ impl<'a> Request<'a> {
write_timeout: Some(Duration::from_secs(DEFAULT_CALL_TIMEOUT)),
timeout: Duration::from_secs(DEFAULT_REQ_TIMEOUT),
root_cert_file_pem: None,
#[cfg(feature = "rust-tls")]
rustls_config: None,
}
}

Expand Down Expand Up @@ -788,6 +794,15 @@ impl<'a> Request<'a> {
self
}

/// Sets a custom rustls `ClientConfig` to use for the TLS connection.
/// When set, overrides `root_cert_file_pem` and allows full control over
/// the root store, client certificate, and certificate verifier.
#[cfg(feature = "rust-tls")]
pub fn rustls_config(&mut self, config: std::sync::Arc<rustls::ClientConfig>) -> &mut Self {
self.rustls_config = Some(config);
self
}

/// Sets the redirect policy for the request.
///
/// # Examples
Expand Down Expand Up @@ -828,15 +843,37 @@ impl<'a> Request<'a> {
where
T: Write,
{
#[cfg(feature = "tracing")]
let span = tracing::info_span!(
"http_request",
otel.name = %format!("{} {}", self.message.method, self.message.uri.host().unwrap_or("")),
otel.kind = "client",
http.method = %self.message.method,
http.url = %self.message.uri,
http.status_code = tracing::field::Empty,
http.duration_ms = tracing::field::Empty,
);
#[cfg(feature = "tracing")]
let _guard = span.enter();
#[cfg(feature = "tracing")]
let start = Instant::now();

// Set up a stream.
let mut stream = Stream::connect(self.message.uri, self.connect_timeout)?;
stream.set_read_timeout(self.read_timeout)?;
stream.set_write_timeout(self.write_timeout)?;

#[cfg(any(feature = "native-tls", feature = "rust-tls"))]
#[cfg(feature = "native-tls")]
{
stream = Stream::try_to_https(stream, self.message.uri, self.root_cert_file_pem)?;
}
#[cfg(feature = "rust-tls")]
{
stream = match self.rustls_config.take() {
Some(config) => Stream::try_to_https_with_config(stream, self.message.uri, config)?,
None => Stream::try_to_https(stream, self.message.uri, self.root_cert_file_pem)?,
};
}

// Send the request message to the stream.
let request_msg = self.message.parse();
Expand Down Expand Up @@ -868,6 +905,13 @@ impl<'a> Request<'a> {
raw_response_head.receive(&receiver, deadline)?;
let response = Response::from_head(&raw_response_head)?;

#[cfg(feature = "tracing")]
{
let status: u16 = response.status_code().into();
Span::current().record("http.status_code", status as i64);
Span::current().record("http.duration_ms", start.elapsed().as_millis() as i64);
}

if response.status_code().is_redirect() {
if let Some(location) = response.headers().get("Location") {
if self.redirect_policy.follow(&location) {
Expand Down
18 changes: 18 additions & 0 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ impl Stream {
}
}

/// Tries to establish a secure connection using a pre-built `rustls::ClientConfig`.
/// Use this when you need a custom root store, client certificate, or certificate verifier.
#[cfg(feature = "rust-tls")]
pub fn try_to_https_with_config(stream: Stream, uri: &Uri, config: std::sync::Arc<rustls::ClientConfig>) -> Result<Stream, Error> {
match stream {
Stream::Http(http_stream) => {
if uri.scheme() == "https" {
let host = uri.host().ok_or(Error::Parse(ParseErr::UriErr))?;
let conn = tls::connect_with_config(config, host, http_stream)?;
Ok(Stream::Https(conn))
} else {
Ok(Stream::Http(http_stream))
}
}
Stream::Https(_) => Ok(stream),
}
}

/// Sets the read timeout on the underlying TCP stream.
pub fn set_read_timeout(&mut self, dur: Option<Duration>) -> Result<(), Error> {
match self {
Expand Down
Loading