Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ instant-acme = { version = "0.8.2", default-features = false, features = [
"ring",
"hyper-rustls",
] }
mail-auth = "0.8.0"
mail-parser = { version = "0.11.3", features = ["full_encoding"] }
mockall = "0.13.0"
parse-size = { version = "1.1.0", features = ["std"] }
prometheus = "0.14.0"
Expand All @@ -64,6 +66,7 @@ rustls = { version = "0.23.18", default-features = false, features = [
"brotli",
] }
serde = { version = "1.0.214", features = ["derive"] }
smtp-proto = "0.2.1"
socket2 = "0.6.0"
strum = { version = "0.28", features = ["derive"] }
strum_macros = "0.28"
Expand Down
9 changes: 9 additions & 0 deletions ic-bn-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ ic-bn-lib-common = "0.1"
indoc = "2.0.6"
instant-acme = { workspace = true, optional = true }
itertools = "0.14.0"
mail-auth = { workspace = true }
mail-parser = { workspace = true }
moka = { version = "0.12.15", features = ["sync", "future"] }
nix = { version = "0.30.0", features = ["signal"] }
ppp = "2.3.0"
Expand Down Expand Up @@ -107,13 +109,15 @@ sev = { version = "7.1.0", optional = true, default-features = false, features =
] }
sha1 = "0.11.0"
sha2 = { version = "0.11.0", optional = true }
smtp-proto = { workspace = true }
socket2 = "0.6.0"
strum = { workspace = true }
strum_macros = { workspace = true }
systemstat = "0.2.3"
tar = { version = "0.4.44", optional = true }
tempfile = "3.23"
thiserror = { workspace = true }
tracing-subscriber = "0.3"
tokio = { version = "1.47.0", features = ["full"] }
tokio-util = { workspace = true }
tokio-rustls = { version = "0.26.0", default-features = false, features = [
Expand Down Expand Up @@ -145,6 +149,7 @@ mock-io = { version = "0.3.2", features = ["full"] }
# Do not upgrade unless you want to upgrade `rand` too
rand_regex = "=0.17.0"
tempfile = "3.20.0"
tokio-test = "0.4"

[[bench]]
name = "vector"
Expand All @@ -154,3 +159,7 @@ required-features = ["vector"]
[package.metadata.cargo-all-features]
# Limit feature combinations to reduce test duration
max_combination_size = 2

[[bin]]
name = "smtp-server"
path = "tools/smtp_server.rs"
79 changes: 1 addition & 78 deletions ic-bn-lib/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,8 @@ pub mod proxy;
pub mod server;
pub mod shed;

use std::{
io,
pin::{Pin, pin},
sync::{Arc, atomic::Ordering},
task::{Context, Poll},
};

use axum::response::{IntoResponse, Redirect};
use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery};
use ic_bn_lib_common::types::http::Stats;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "clients-hyper")]
pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded};
Expand All @@ -28,11 +19,7 @@ pub use client::clients_reqwest::{
pub use server::{Server, ServerBuilder};
use url::Url;

use crate::http::headers::X_FORWARDED_HOST;

/// Blanket async read+write trait for streams Box-ing
trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
use crate::{http::headers::X_FORWARDED_HOST, network::AsyncReadWrite};

/// Calculate very approximate HTTP request/response headers size in bytes.
/// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed.
Expand Down Expand Up @@ -101,70 +88,6 @@ pub fn extract_authority<T>(request: &Request<T>) -> Option<&str> {
.and_then(extract_host)
}

/// Async read+write wrapper that counts bytes read/written
struct AsyncCounter<T: AsyncReadWrite> {
inner: T,
stats: Arc<Stats>,
}

impl<T: AsyncReadWrite> AsyncCounter<T> {
/// Create new `AsyncCounter`
pub fn new(inner: T) -> (Self, Arc<Stats>) {
let stats = Arc::new(Stats::new());

(
Self {
inner,
stats: stats.clone(),
},
stats,
)
}
}

impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let size_before = buf.filled().len();
let poll = pin!(&mut self.inner).poll_read(cx, buf);
if matches!(&poll, Poll::Ready(Ok(()))) {
let rcvd = buf.filled().len() - size_before;
self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
}

poll
}
}

impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let poll = pin!(&mut self.inner).poll_write(cx, buf);
if let Poll::Ready(Ok(v)) = &poll {
self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
}

poll
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
pin!(&mut self.inner).poll_shutdown(cx)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
pin!(&mut self.inner).poll_flush(cx)
}
}

/// Error that might happen during Url to Uri conversion
#[derive(thiserror::Error, Debug)]
pub enum UrlToUriError {
Expand Down
150 changes: 9 additions & 141 deletions ic-bn-lib/src/http/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ pub mod proxy_protocol;

use std::{
fmt::Display,
io,
net::SocketAddr,
os::unix::fs::PermissionsExt,
path::PathBuf,
sync::{
Arc,
Expand Down Expand Up @@ -39,107 +37,32 @@ use prometheus::{
use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
use rustls::sign::SingleCertAndKey;
use scopeguard::defer;
use socket2::{Domain, Socket, Type};
use tokio::{
io::AsyncWriteExt,
net::{TcpListener, UnixListener, UnixSocket},
pin, select,
sync::mpsc::channel,
time::{sleep, timeout},
};
use tokio_io_timeout::TimeoutStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tower_service::Service;
use tracing::{debug, info, warn};
use uuid::Uuid;

use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody};
use crate::tls::{pem_convert_to_rustls, prepare_server_config};
use super::body::NotifyingBody;
use crate::{
network::{AsyncCounter, AsyncReadWrite, listener::Listener, tls_handshake},
tls::{pem_convert_to_rustls, prepare_server_config},
};

const YEAR: Duration = Duration::from_secs(86400 * 365);

/// Connection listener
pub enum Listener {
Tcp(TcpListener),
Unix(UnixListener),
}

impl Listener {
/// Create a new Listener
pub fn new(addr: Addr, opts: ListenerOpts) -> Result<Self, Error> {
Ok(match addr {
Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?),
Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?),
})
}

/// Accept the connection
async fn accept(&self) -> Result<(Box<dyn AsyncReadWrite>, Addr), io::Error> {
Ok(match self {
Self::Tcp(v) => {
let x = v.accept().await?;
(Box::new(x.0), Addr::Tcp(x.1))
}
Self::Unix(v) => {
let x = v.accept().await?;
(
Box::new(x.0),
Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()),
)
}
})
}

pub fn local_addr(&self) -> Option<SocketAddr> {
match &self {
Self::Tcp(v) => v.local_addr().ok(),
Self::Unix(_) => None,
}
}
}

impl From<TcpListener> for Listener {
/// Creates a Listener from TcpListener
fn from(v: TcpListener) -> Self {
Self::Tcp(v)
}
}

impl From<UnixListener> for Listener {
/// Creates a Listener from UnixListener
fn from(v: UnixListener) -> Self {
Self::Unix(v)
}
}

#[derive(Clone)]
enum RequestState {
Start,
End,
}

async fn tls_handshake(
rustls_cfg: Arc<rustls::ServerConfig>,
stream: impl AsyncReadWrite,
) -> Result<(impl AsyncReadWrite, TlsInfo), Error> {
let tls_acceptor = TlsAcceptor::from(rustls_cfg);

// Perform the TLS handshake
let start = Instant::now();
let stream = tls_acceptor
.accept(stream)
.await
.context("TLS accept failed")?;
let duration = start.elapsed();

let conn = stream.get_ref().1;
let mut tls_info = TlsInfo::try_from(conn)?;
tls_info.handshake_dur = duration;

Ok((stream, tls_info))
}

struct Conn {
addr: Addr,
remote_addr: Addr,
Expand Down Expand Up @@ -638,7 +561,8 @@ impl Server {
keepalive: (&self.options).into(),
};

let listener = Listener::new(self.addr.clone(), opts)?;
let listener =
Listener::new(self.addr.clone(), opts).context("unable to create listener")?;
self.serve_with_listener(listener, token).await
}

Expand Down Expand Up @@ -733,64 +657,6 @@ impl Server {
}
}

/// Creates a TCP listener with given opts
pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result<TcpListener, Error> {
let domain = if addr.is_ipv4() {
Domain::IPV4
} else {
Domain::IPV6
};

let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?;
socket
.set_tcp_nodelay(true)
.context("unable to set TCP_NODELAY")?;

if let Some(v) = opts.mss {
socket.set_tcp_mss(v).context("unable to set TCP MSS")?;
}

socket
.set_reuse_address(true)
.context("unable to set SO_REUSEADDR")?;
socket
.set_tcp_keepalive(&opts.keepalive)
.context("unable to set keepalive on the socket")?;
socket
.set_nonblocking(true)
.context("unable to set socket into non-blocking mode")?;

socket.bind(&addr.into()).context("unable to bind socket")?;
socket
.listen(opts.backlog as i32)
.context("unable to listen on the socket")?;

let listener = TcpListener::from_std(socket.into())
.context("unable to convert socket from the standard one")?;

Ok(listener)
}

/// Creates a Unix Socket listener with given opts
pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result<UnixListener, Error> {
let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?;

if path.exists() {
std::fs::remove_file(&path).context("unable to remove UNIX socket")?;
}

socket.bind(&path).context("unable to bind socket")?;

let socket = socket
.listen(opts.backlog)
.context("unable to listen socket")?;

std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666))
.context("unable to set permissions on socket")?;

Ok(socket)
}

#[async_trait]
impl Run for Server {
async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
Expand All @@ -803,6 +669,8 @@ impl Run for Server {
mod test {
use http::StatusCode;

use crate::network::listener::listen_tcp;

use super::*;

#[tokio::test]
Expand Down
3 changes: 3 additions & 0 deletions ic-bn-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
#![warn(tail_expr_drop_order)]
#![allow(clippy::cognitive_complexity)]
#![allow(clippy::field_reassign_with_default)]
#![allow(clippy::collapsible_if)]

#[cfg(feature = "custom-domains")]
pub mod custom_domains;
pub mod http;
pub mod network;
pub mod pubsub;
pub mod smtp;
pub mod tasks;
pub mod tests;
pub mod tls;
Expand Down
Loading
Loading