diff --git a/Cargo.toml b/Cargo.toml index 3083046..67d3fb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,8 @@ instant-acme = { version = "0.8.2", default-features = false, features = [ "ring", "hyper-rustls", ] } +mail-auth = "0.8.0" +mail-parser = { version = "0.11.3", features = ["full_encoding"] } mockall = "0.13.0" parse-size = { version = "1.1.0", features = ["std"] } prometheus = "0.14.0" @@ -64,6 +66,7 @@ rustls = { version = "0.23.18", default-features = false, features = [ "brotli", ] } serde = { version = "1.0.214", features = ["derive"] } +smtp-proto = "0.2.1" socket2 = "0.6.0" strum = { version = "0.28", features = ["derive"] } strum_macros = "0.28" diff --git a/ic-bn-lib/Cargo.toml b/ic-bn-lib/Cargo.toml index a69c2f4..0fcaa72 100644 --- a/ic-bn-lib/Cargo.toml +++ b/ic-bn-lib/Cargo.toml @@ -74,6 +74,8 @@ ic-bn-lib-common = "0.1" indoc = "2.0.6" instant-acme = { workspace = true, optional = true } itertools = "0.14.0" +mail-auth = { workspace = true } +mail-parser = { workspace = true } moka = { version = "0.12.15", features = ["sync", "future"] } nix = { version = "0.30.0", features = ["signal"] } ppp = "2.3.0" @@ -107,6 +109,7 @@ sev = { version = "7.1.0", optional = true, default-features = false, features = ] } sha1 = "0.11.0" sha2 = { version = "0.11.0", optional = true } +smtp-proto = { workspace = true } socket2 = "0.6.0" strum = { workspace = true } strum_macros = { workspace = true } @@ -114,6 +117,7 @@ systemstat = "0.2.3" tar = { version = "0.4.44", optional = true } tempfile = "3.23" thiserror = { workspace = true } +tracing-subscriber = "0.3" tokio = { version = "1.47.0", features = ["full"] } tokio-util = { workspace = true } tokio-rustls = { version = "0.26.0", default-features = false, features = [ @@ -145,6 +149,7 @@ mock-io = { version = "0.3.2", features = ["full"] } # Do not upgrade unless you want to upgrade `rand` too rand_regex = "=0.17.0" tempfile = "3.20.0" +tokio-test = "0.4" [[bench]] name = "vector" @@ -154,3 +159,7 @@ required-features = ["vector"] [package.metadata.cargo-all-features] # Limit feature combinations to reduce test duration max_combination_size = 2 + +[[bin]] +name = "smtp-server" +path = "tools/smtp_server.rs" diff --git a/ic-bn-lib/src/http/mod.rs b/ic-bn-lib/src/http/mod.rs index 84b8548..476a357 100644 --- a/ic-bn-lib/src/http/mod.rs +++ b/ic-bn-lib/src/http/mod.rs @@ -8,17 +8,8 @@ pub mod proxy; pub mod server; pub mod shed; -use std::{ - io, - pin::{Pin, pin}, - sync::{Arc, atomic::Ordering}, - task::{Context, Poll}, -}; - use axum::response::{IntoResponse, Redirect}; use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery}; -use ic_bn_lib_common::types::http::Stats; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "clients-hyper")] pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded}; @@ -28,11 +19,7 @@ pub use client::clients_reqwest::{ pub use server::{Server, ServerBuilder}; use url::Url; -use crate::http::headers::X_FORWARDED_HOST; - -/// Blanket async read+write trait for streams Box-ing -trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} -impl AsyncReadWrite for T {} +use crate::{http::headers::X_FORWARDED_HOST, network::AsyncReadWrite}; /// Calculate very approximate HTTP request/response headers size in bytes. /// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed. @@ -101,70 +88,6 @@ pub fn extract_authority(request: &Request) -> Option<&str> { .and_then(extract_host) } -/// Async read+write wrapper that counts bytes read/written -struct AsyncCounter { - inner: T, - stats: Arc, -} - -impl AsyncCounter { - /// Create new `AsyncCounter` - pub fn new(inner: T) -> (Self, Arc) { - let stats = Arc::new(Stats::new()); - - ( - Self { - inner, - stats: stats.clone(), - }, - stats, - ) - } -} - -impl AsyncRead for AsyncCounter { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let size_before = buf.filled().len(); - let poll = pin!(&mut self.inner).poll_read(cx, buf); - if matches!(&poll, Poll::Ready(Ok(()))) { - let rcvd = buf.filled().len() - size_before; - self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst); - } - - poll - } -} - -impl AsyncWrite for AsyncCounter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let poll = pin!(&mut self.inner).poll_write(cx, buf); - if let Poll::Ready(Ok(v)) = &poll { - self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst); - } - - poll - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - pin!(&mut self.inner).poll_shutdown(cx) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - pin!(&mut self.inner).poll_flush(cx) - } -} - /// Error that might happen during Url to Uri conversion #[derive(thiserror::Error, Debug)] pub enum UrlToUriError { diff --git a/ic-bn-lib/src/http/server/mod.rs b/ic-bn-lib/src/http/server/mod.rs index ef8056f..f7be719 100644 --- a/ic-bn-lib/src/http/server/mod.rs +++ b/ic-bn-lib/src/http/server/mod.rs @@ -2,9 +2,7 @@ pub mod proxy_protocol; use std::{ fmt::Display, - io, net::SocketAddr, - os::unix::fs::PermissionsExt, path::PathBuf, sync::{ Arc, @@ -39,107 +37,32 @@ use prometheus::{ use proxy_protocol::{ProxyHeader, ProxyProtocolStream}; use rustls::sign::SingleCertAndKey; use scopeguard::defer; -use socket2::{Domain, Socket, Type}; use tokio::{ io::AsyncWriteExt, - net::{TcpListener, UnixListener, UnixSocket}, pin, select, sync::mpsc::channel, time::{sleep, timeout}, }; use tokio_io_timeout::TimeoutStream; -use tokio_rustls::TlsAcceptor; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tower_service::Service; use tracing::{debug, info, warn}; use uuid::Uuid; -use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody}; -use crate::tls::{pem_convert_to_rustls, prepare_server_config}; +use super::body::NotifyingBody; +use crate::{ + network::{AsyncCounter, AsyncReadWrite, listener::Listener, tls_handshake}, + tls::{pem_convert_to_rustls, prepare_server_config}, +}; const YEAR: Duration = Duration::from_secs(86400 * 365); -/// Connection listener -pub enum Listener { - Tcp(TcpListener), - Unix(UnixListener), -} - -impl Listener { - /// Create a new Listener - pub fn new(addr: Addr, opts: ListenerOpts) -> Result { - Ok(match addr { - Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?), - Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?), - }) - } - - /// Accept the connection - async fn accept(&self) -> Result<(Box, Addr), io::Error> { - Ok(match self { - Self::Tcp(v) => { - let x = v.accept().await?; - (Box::new(x.0), Addr::Tcp(x.1)) - } - Self::Unix(v) => { - let x = v.accept().await?; - ( - Box::new(x.0), - Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()), - ) - } - }) - } - - pub fn local_addr(&self) -> Option { - match &self { - Self::Tcp(v) => v.local_addr().ok(), - Self::Unix(_) => None, - } - } -} - -impl From for Listener { - /// Creates a Listener from TcpListener - fn from(v: TcpListener) -> Self { - Self::Tcp(v) - } -} - -impl From for Listener { - /// Creates a Listener from UnixListener - fn from(v: UnixListener) -> Self { - Self::Unix(v) - } -} - #[derive(Clone)] enum RequestState { Start, End, } -async fn tls_handshake( - rustls_cfg: Arc, - stream: impl AsyncReadWrite, -) -> Result<(impl AsyncReadWrite, TlsInfo), Error> { - let tls_acceptor = TlsAcceptor::from(rustls_cfg); - - // Perform the TLS handshake - let start = Instant::now(); - let stream = tls_acceptor - .accept(stream) - .await - .context("TLS accept failed")?; - let duration = start.elapsed(); - - let conn = stream.get_ref().1; - let mut tls_info = TlsInfo::try_from(conn)?; - tls_info.handshake_dur = duration; - - Ok((stream, tls_info)) -} - struct Conn { addr: Addr, remote_addr: Addr, @@ -638,7 +561,8 @@ impl Server { keepalive: (&self.options).into(), }; - let listener = Listener::new(self.addr.clone(), opts)?; + let listener = + Listener::new(self.addr.clone(), opts).context("unable to create listener")?; self.serve_with_listener(listener, token).await } @@ -733,64 +657,6 @@ impl Server { } } -/// Creates a TCP listener with given opts -pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result { - let domain = if addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }; - - let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?; - socket - .set_tcp_nodelay(true) - .context("unable to set TCP_NODELAY")?; - - if let Some(v) = opts.mss { - socket.set_tcp_mss(v).context("unable to set TCP MSS")?; - } - - socket - .set_reuse_address(true) - .context("unable to set SO_REUSEADDR")?; - socket - .set_tcp_keepalive(&opts.keepalive) - .context("unable to set keepalive on the socket")?; - socket - .set_nonblocking(true) - .context("unable to set socket into non-blocking mode")?; - - socket.bind(&addr.into()).context("unable to bind socket")?; - socket - .listen(opts.backlog as i32) - .context("unable to listen on the socket")?; - - let listener = TcpListener::from_std(socket.into()) - .context("unable to convert socket from the standard one")?; - - Ok(listener) -} - -/// Creates a Unix Socket listener with given opts -pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result { - let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?; - - if path.exists() { - std::fs::remove_file(&path).context("unable to remove UNIX socket")?; - } - - socket.bind(&path).context("unable to bind socket")?; - - let socket = socket - .listen(opts.backlog) - .context("unable to listen socket")?; - - std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666)) - .context("unable to set permissions on socket")?; - - Ok(socket) -} - #[async_trait] impl Run for Server { async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> { @@ -803,6 +669,8 @@ impl Run for Server { mod test { use http::StatusCode; + use crate::network::listener::listen_tcp; + use super::*; #[tokio::test] diff --git a/ic-bn-lib/src/lib.rs b/ic-bn-lib/src/lib.rs index 19e3493..02e7f5b 100644 --- a/ic-bn-lib/src/lib.rs +++ b/ic-bn-lib/src/lib.rs @@ -4,11 +4,14 @@ #![warn(tail_expr_drop_order)] #![allow(clippy::cognitive_complexity)] #![allow(clippy::field_reassign_with_default)] +#![allow(clippy::collapsible_if)] #[cfg(feature = "custom-domains")] pub mod custom_domains; pub mod http; +pub mod network; pub mod pubsub; +pub mod smtp; pub mod tasks; pub mod tests; pub mod tls; diff --git a/ic-bn-lib/src/network/listener.rs b/ic-bn-lib/src/network/listener.rs new file mode 100644 index 0000000..b5403d8 --- /dev/null +++ b/ic-bn-lib/src/network/listener.rs @@ -0,0 +1,100 @@ +use std::{io, net::SocketAddr, os::unix::fs::PermissionsExt, path::PathBuf}; + +use ic_bn_lib_common::types::http::{Addr, ListenerOpts}; +use socket2::{Domain, Socket, Type}; +use tokio::net::{TcpListener, UnixListener, UnixSocket}; + +use crate::network::AsyncReadWrite; + +/// Generic connection listener +pub enum Listener { + Tcp(TcpListener), + Unix(UnixListener), +} + +impl Listener { + /// Create a new Listener + pub fn new(addr: Addr, opts: ListenerOpts) -> io::Result { + Ok(match addr { + Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?), + Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?), + }) + } + + /// Accept the connection + pub async fn accept(&self) -> io::Result<(Box, Addr)> { + Ok(match self { + Self::Tcp(v) => { + let x = v.accept().await?; + (Box::new(x.0), Addr::Tcp(x.1)) + } + Self::Unix(v) => { + let x = v.accept().await?; + ( + Box::new(x.0), + Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()), + ) + } + }) + } + + pub fn local_addr(&self) -> Option { + match &self { + Self::Tcp(v) => v.local_addr().ok(), + Self::Unix(_) => None, + } + } +} + +impl From for Listener { + /// Creates a Listener from TcpListener + fn from(v: TcpListener) -> Self { + Self::Tcp(v) + } +} + +impl From for Listener { + /// Creates a Listener from UnixListener + fn from(v: UnixListener) -> Self { + Self::Unix(v) + } +} + +/// Creates a TCP listener with given opts +pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> io::Result { + let domain = if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }; + + let socket = Socket::new(domain, Type::STREAM, None)?; + socket.set_tcp_nodelay(true)?; + + if let Some(v) = opts.mss { + socket.set_tcp_mss(v)?; + } + + socket.set_reuse_address(true)?; + socket.set_tcp_keepalive(&opts.keepalive)?; + socket.set_nonblocking(true)?; + socket.bind(&addr.into())?; + socket.listen(opts.backlog as i32)?; + + TcpListener::from_std(socket.into()) +} + +/// Creates a Unix Socket listener with given opts +pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> io::Result { + let socket = UnixSocket::new_stream()?; + + if path.exists() { + std::fs::remove_file(&path)?; + } + + socket.bind(&path)?; + let socket = socket.listen(opts.backlog)?; + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666))?; + + Ok(socket) +} diff --git a/ic-bn-lib/src/network/mod.rs b/ic-bn-lib/src/network/mod.rs new file mode 100644 index 0000000..bd39cc9 --- /dev/null +++ b/ic-bn-lib/src/network/mod.rs @@ -0,0 +1,101 @@ +use std::{ + io, + pin::{Pin, pin}, + sync::{Arc, atomic::Ordering}, + task::{Context, Poll}, + time::Instant, +}; + +use ic_bn_lib_common::types::http::{Stats, TlsInfo}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{TlsAcceptor, server::TlsStream}; + +pub mod listener; + +/// Blanket async read+write trait for streams `Box`-ing +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} +impl AsyncReadWrite for T {} + +/// Performs TLS handshake on the given stream +pub async fn tls_handshake( + rustls_cfg: Arc, + stream: T, +) -> io::Result<(TlsStream, TlsInfo)> { + let tls_acceptor = TlsAcceptor::from(rustls_cfg); + + // Perform the TLS handshake + let start = Instant::now(); + let stream = tls_acceptor.accept(stream).await?; + let duration = start.elapsed(); + + // Obtain TLS info + let conn = stream.get_ref().1; + let mut tls_info = TlsInfo::try_from(conn).map_err(io::Error::other)?; + tls_info.handshake_dur = duration; + + Ok((stream, tls_info)) +} + +/// Async read+write wrapper that counts bytes read/written +pub struct AsyncCounter { + inner: T, + stats: Arc, +} + +impl AsyncCounter { + /// Create new `AsyncCounter` + pub fn new(inner: T) -> (Self, Arc) { + let stats = Arc::new(Stats::new()); + + ( + Self { + inner, + stats: stats.clone(), + }, + stats, + ) + } +} + +impl AsyncRead for AsyncCounter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let size_before = buf.filled().len(); + let poll = pin!(&mut self.inner).poll_read(cx, buf); + if matches!(&poll, Poll::Ready(Ok(()))) { + let rcvd = buf.filled().len() - size_before; + self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst); + } + + poll + } +} + +impl AsyncWrite for AsyncCounter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let poll = pin!(&mut self.inner).poll_write(cx, buf); + if let Poll::Ready(Ok(v)) = &poll { + self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst); + } + + poll + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + pin!(&mut self.inner).poll_shutdown(cx) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&mut self.inner).poll_flush(cx) + } +} diff --git a/ic-bn-lib/src/smtp/address.rs b/ic-bn-lib/src/smtp/address.rs new file mode 100644 index 0000000..5d76b6d --- /dev/null +++ b/ic-bn-lib/src/smtp/address.rs @@ -0,0 +1,83 @@ +use std::{fmt::Display, str::FromStr}; + +use derive_new::new; +use fqdn::FQDN; + +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq)] +pub enum EmailAddressError { + #[error("@ is missing")] + AtMissing, + #[error("Domain incorrect: {0}")] + DomainIncorrect(String), +} + +/// E-Mail address representation. +/// +/// Currently we don't validate the local part at all +/// and just consider everything to the right from the +/// rightmost @ as a domain part. +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, new)] +pub struct EmailAddress { + pub local: String, + pub domain: FQDN, +} + +impl Display for EmailAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}@{}", self.local, self.domain) + } +} + +impl FromStr for EmailAddress { + type Err = EmailAddressError; + + fn from_str(s: &str) -> Result { + let (local, domain) = s.rsplit_once('@').ok_or(EmailAddressError::AtMissing)?; + if domain.is_empty() { + return Err(EmailAddressError::DomainIncorrect("Empty domain".into())); + } + + let domain = FQDN::from_ascii_str(domain) + .map_err(|e| EmailAddressError::DomainIncorrect(e.to_string()))?; + + Ok(Self { + local: local.into(), + domain, + }) + } +} + +impl TryFrom<&str> for EmailAddress { + type Error = EmailAddressError; + + fn try_from(value: &str) -> Result { + Self::from_str(value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_email_address() { + // ok + for v in ["foo@bar", "john.doe@jane.doe", "\"foo+bar@baz\"@dead.beef"] { + assert_eq!(EmailAddress::from_str(v).unwrap().to_string(), v); + } + + // no @ + assert_eq!( + EmailAddress::from_str("foo").unwrap_err(), + EmailAddressError::AtMissing + ); + + // bad domain + for v in ["foo@bar\"baz", "\"jane@doe\""] { + assert!(matches!( + EmailAddress::from_str(v).unwrap_err(), + EmailAddressError::DomainIncorrect(_) + )); + } + } +} diff --git a/ic-bn-lib/src/smtp/ic/candid.rs b/ic-bn-lib/src/smtp/ic/candid.rs new file mode 100644 index 0000000..8ef72c1 --- /dev/null +++ b/ic-bn-lib/src/smtp/ic/candid.rs @@ -0,0 +1,57 @@ +//! Candid types for an IC SMTP Protocol + +use candid::{CandidType, Deserialize}; + +/// Candid `Header`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Header { + pub name: String, + pub value: String, +} + +/// Candid `Message`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Message { + pub headers: Vec
, + pub body: Vec, +} + +/// Candid `Address`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Address { + pub user: String, + pub domain: String, +} + +/// Candid `Envelope`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Envelope { + pub from: Address, + pub to: Address, +} + +/// Candid `SmtpRequest`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequest { + pub message: Option, + pub envelope: Option, + pub gateway_flags: Option>, +} + +/// Candid `SmtpRequestError` (`code` is `nat64` on the wire in typical canisters). +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequestError { + pub code: u64, + pub message: String, +} + +/// Candid `SmtpResponse` — `Ok` carries an empty record. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub enum SmtpResponse { + Ok(SmtpOk), + Err(SmtpRequestError), +} + +/// Empty record for variant `Ok`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpOk {} diff --git a/ic-bn-lib/src/smtp/ic/mod.rs b/ic-bn-lib/src/smtp/ic/mod.rs new file mode 100644 index 0000000..938a4a0 --- /dev/null +++ b/ic-bn-lib/src/smtp/ic/mod.rs @@ -0,0 +1 @@ +pub mod candid; diff --git a/ic-bn-lib/src/smtp/inbound/ehlo.rs b/ic-bn-lib/src/smtp/inbound/ehlo.rs new file mode 100644 index 0000000..cd4b7d8 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/ehlo.rs @@ -0,0 +1,90 @@ +use std::str::FromStr; + +use fqdn::FQDN; +use mail_auth::hickory_resolver::proto::ProtoErrorKind; +use smtp_proto::{ + EXT_8BIT_MIME, EXT_CHUNKING, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, + EhloResponse, +}; + +use crate::{ + network::AsyncReadWrite, + smtp::inbound::{Session, SessionResult}, +}; + +impl Session { + /// Handles EHLO/HELO commands + pub async fn handle_ehlo(&mut self, host: &str, extended: bool) -> SessionResult<()> { + // Validate hostname + let Ok(ehlo_hostname) = FQDN::from_str(host) else { + return self.reply("550", "5.5.0", "Invalid EHLO hostname.").await; + }; + + // If EHLO hostname is already set to the same value - just reply directly, + // avoid redundant checks + if let Some(v) = &self.data.ehlo_hostname + && v == &ehlo_hostname + { + return self.send_ehlo(extended).await; + } + + if ehlo_hostname.depth() < 2 { + return self + .reply("550", "5.5.0", "EHLO hostname must be an FQDN.") + .await; + }; + + // Check if EHLO hostname resolves if configured + if self.cfg.verify_ehlo_hostname { + match self.cfg.authenticator.resolver().lookup_ip(host).await { + Ok(v) => { + if v.iter().next().is_none() { + return self + .reply("550", "5.5.0", "EHLO hostname not found in DNS.") + .await; + } + } + + Err(e) => { + if matches!(e.kind(), ProtoErrorKind::NoRecordsFound(_)) { + return self + .reply("550", "5.5.0", "EHLO hostname not found in DNS.") + .await; + } + + return self + .reply("451", "4.7.25", "Temporary error validating EHLO hostname.") + .await; + } + } + } + + self.reset_message(); + self.data.ehlo_hostname = Some(ehlo_hostname); + + return self.send_ehlo(extended).await; + } + + async fn send_ehlo(&mut self, extended: bool) -> SessionResult<()> { + if !extended { + return self + .write(format!("250 {} you had me at HELO\r\n", self.cfg.hostname).as_bytes()) + .await; + } + + let mut response = EhloResponse::new(self.cfg.hostname.as_str()); + response.capabilities = + EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8 | EXT_CHUNKING; + response.size = self.cfg.max_message_size; + + // Send STARTTLS cap only if we support TLS & we're not already in TLS mode + if self.tls_info.is_none() && self.cfg.tls_mode.enabled() { + response.capabilities |= EXT_START_TLS; + } + + let mut buf = Vec::with_capacity(128); + response.write(&mut buf).ok(); + + self.write(&buf).await + } +} diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs new file mode 100644 index 0000000..6a81463 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -0,0 +1,125 @@ +use std::{borrow::Cow, fmt::Write, str::FromStr}; + +use mail_auth::{IprevResult, Parameters, SpfResult, spf::verify::SpfParameters}; +use smtp_proto::{MAIL_BY_NOTIFY, MAIL_BY_RETURN, MailFrom}; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + address::EmailAddress, + inbound::{Session, SessionResult}, + }, +}; + +impl Session { + /// Handles MAIL FROM command + pub async fn handle_mail_from(&mut self, from: MailFrom>) -> SessionResult<()> { + let Some(helo_hostname) = &self.data.ehlo_hostname else { + return self + .reply("503", "5.5.1", "Polite people say EHLO first.") + .await; + }; + + if self.data.mail_from.is_some() { + return self + .reply( + "503", + "5.5.1", + "Multiple MAIL FROM commands are not allowed.", + ) + .await; + } + + if self.cfg.tls_mode.required() && self.tls_info.is_none() { + return self + .reply( + "503", + "5.5.1", + "TLS is required to submit mail on this server.", + ) + .await; + } + + if (from.flags & (MAIL_BY_NOTIFY | MAIL_BY_RETURN)) != 0 { + return self.ext_unsupported("DELIVERBY").await; + } + + if from.mt_priority != 0 { + return self.ext_unsupported("MT-PRIORITY").await; + } + + if from.size > self.cfg.max_message_size { + return self.message_too_big().await; + } + + if from.hold_for != 0 || from.hold_until != 0 { + return self.ext_unsupported("FUTURERELEASE").await; + } + + if from.env_id.is_some() { + return self.ext_unsupported("DSN").await; + } + + // Validate address + let Ok(address) = EmailAddress::from_str(&from.address) else { + return self + .reply("550", "5.7.1", "Sender address is incorrect.") + .await; + }; + + // Validate reverse IP if configured + if self.cfg.verify_reverse_ip { + let result = self + .cfg + .authenticator + .verify_iprev(Parameters::from(self.remote_ip)) + .await + .result; + + if !matches!(result, IprevResult::Pass) { + let (code, ext, msg) = if matches!(result, IprevResult::TempError(_)) { + ("451", "4.7.25", "Temporary error validating reverse DNS.") + } else { + ("550", "5.7.25", "Reverse DNS validation failed.") + }; + + return self.reply(code, ext, msg).await; + } + } + + if self.cfg.verify_spf { + let output = self + .cfg + .authenticator + .verify_spf(SpfParameters::verify_mail_from( + self.remote_ip, + &helo_hostname.to_string(), + &self.cfg.hostname, + &from.address, + )) + .await; + + match output.result() { + SpfResult::Pass | SpfResult::Neutral | SpfResult::None => {} + SpfResult::TempError => { + return self + .reply("451", "4.7.24", "Temporary SPF validation error.") + .await; + } + SpfResult::Fail | SpfResult::PermError | SpfResult::SoftFail => { + let mut msg = "SPF validation failed".to_string(); + if let Some(v) = output.explanation() { + write!(msg, ": {v}.").ok(); + } + write!(msg, "\r\n").ok(); + + return self.reply("550", "5.7.23", &msg).await; + } + } + } + + self.reply("250", "2.1.0", "OK").await?; + self.data.mail_from = Some(address); + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/inbound/manager.rs b/ic-bn-lib/src/smtp/inbound/manager.rs new file mode 100644 index 0000000..d8f8a8d --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/manager.rs @@ -0,0 +1,106 @@ +use std::{net::SocketAddr, sync::Arc}; + +use tokio::io::AsyncWriteExt; +use tokio_rustls::server::TlsStream; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info}; + +use crate::{ + network::{AsyncReadWrite, tls_handshake}, + smtp::inbound::{ + Session, SessionConfig, SessionData, SessionError, SessionResult, SessionTlsMode, + SessionUpgrade, + }, +}; + +/// Manages the lifetime of a single SMTP session. +/// +/// It's needed because the SMTP session can transition into TLS state +/// which requires external orchestration. +pub struct SessionManager; + +impl SessionManager { + pub async fn handle_connection( + stream: S, + remote_addr: SocketAddr, + params: Arc, + shutdown_token: CancellationToken, + ) { + let mut session = Session::new(remote_addr.ip(), stream, params); + + match session.handle(shutdown_token.child_token()).await { + Ok(v) => match v { + SessionUpgrade::No => { + session.stream.shutdown().await.ok(); + } + + SessionUpgrade::StartTls => { + Self::starttls(session, shutdown_token.child_token()).await + } + }, + + Err(e) => { + if !matches!(e, SessionError::Quit) { + info!("{session}: error: {e:#}"); + } + + if let Err(e) = session.shutdown().await { + debug!("{session}: error closing connection: {e:#}"); + }; + } + } + } + + /// Converts session into TLS mode + async fn starttls(session: Session, shutdown_token: CancellationToken) { + let session_name = session.to_string(); + + match session.into_tls().await { + Ok(mut session) => { + if let Err(e) = session.handle(shutdown_token.child_token()).await { + if !matches!(e, SessionError::Quit) { + info!("{session}: error: {e:#}"); + } + + if let Err(e) = session.shutdown().await { + debug!("{session}: error closing connection: {e:#}"); + }; + } + } + + Err(e) => { + info!("{session_name}: TLS handshake failed: {e:#}"); + } + }; + } +} + +impl Session { + /// Converts the plain-text session into a TLS one by doing a TLS handshake + pub async fn into_tls(self) -> SessionResult>> { + // SAFETY: We should end up here only if TLS is enabled. + // It's better to panic otherwise. + let tls_config = match &self.cfg.tls_mode { + SessionTlsMode::Allowed(v) | SessionTlsMode::Required(v) => v.clone(), + SessionTlsMode::Disabled => { + unreachable!("Session::into_tls() called with TLS disabled") + } + }; + + let (stream, tls_info) = tls_handshake(tls_config, self.stream).await?; + + Ok(Session { + id: self.id, + remote_ip: self.remote_ip, + stream, + state: self.state, + // According to the RFC we need to discard all session data + // after switching into TLS mode. + // https://datatracker.ietf.org/doc/html/rfc3207#section-4.2 + data: SessionData::default(), + counters: self.counters, + cfg: self.cfg, + tls_info: Some(tls_info), + }) + } +} diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs new file mode 100644 index 0000000..20c5e59 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -0,0 +1,754 @@ +pub mod ehlo; +pub mod mail_from; +pub mod manager; +pub mod rcpt_to; +pub mod session; + +use std::{ + fmt::Display, + io, + net::IpAddr, + sync::Arc, + time::{Duration, Instant}, +}; + +use bytes::Bytes; +use fqdn::FQDN; +use ic_bn_lib_common::types::http::TlsInfo; +use mail_auth::MessageAuthenticator; +use rustls::ServerConfig; +use smtp_proto::{ + Error as SmtpError, + request::receiver::{ + BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver, RequestReceiver, + }, +}; +use strum::Display; +use uuid::Uuid; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + DeliversMail, DummyDeliveryAgent, DummyRecipientResolver, ResolvesRecipient, + address::EmailAddress, + }, +}; + +#[derive(thiserror::Error, Debug)] +pub enum SessionError { + #[error("I/O error: {0}")] + Io(#[from] io::Error), + #[error("Timed out")] + Timeout, + #[error("{0}")] + SmtpError(#[from] SmtpError), + #[error("Session terminated by client (QUIT)")] + Quit, + #[error("Client is sending before greeting")] + SendsBeforeGreeting, + #[error("Too many messages per session")] + TooManyMessagesPerSession, + #[error("Session transfer quota ({0} bytes) was exceeded")] + TransferQuotaExceeded(usize), + #[error("Session TTL ({0}s) was exceeded")] + TtlExceeded(u64), + #[error("Too many errors")] + TooManyErrors, + #[error("{0}")] + Other(#[from] anyhow::Error), +} + +/// Indicates if a session needs to be upgraded to TLS +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionUpgrade { + No, + StartTls, +} + +pub type SessionResult = Result; + +/// Session TLS mode +pub enum SessionTlsMode { + Disabled, + Allowed(Arc), + Required(Arc), +} + +impl SessionTlsMode { + pub const fn enabled(&self) -> bool { + matches!(self, Self::Allowed(_) | Self::Required(_)) + } + + pub const fn required(&self) -> bool { + matches!(self, Self::Required(_)) + } +} + +/// SMTP session config +pub struct SessionConfig { + hostname: String, + greeting: Bytes, + + pub max_message_size: usize, + pub max_recipients: usize, + pub max_session_duration: Duration, + pub max_session_data: usize, + pub max_errors: usize, + pub max_messages_per_session: usize, + + pub verify_ehlo_hostname: bool, + pub verify_sender_domain: bool, + pub verify_reverse_ip: bool, + pub verify_spf: bool, + pub helo_delay: Option, + + pub timeout: Duration, + pub tls_mode: SessionTlsMode, + + pub authenticator: Arc, + pub recipient_resolver: Arc, + pub delivery_agent: Arc, +} + +impl SessionConfig { + pub fn new(hostname: &str) -> Self { + let greeting = format!("220 {hostname} ESMTP IC SMTP Gateway\r\n"); + + Self { + hostname: hostname.into(), + greeting: Bytes::from(greeting), + max_message_size: 10 * 1024 * 1024, + max_recipients: 5, + max_session_duration: Duration::from_secs(600), + max_session_data: 50 * 1024 * 1024, + max_errors: 5, + max_messages_per_session: 5, + verify_ehlo_hostname: false, + verify_reverse_ip: false, + verify_sender_domain: false, + verify_spf: false, + helo_delay: None, + timeout: Duration::from_secs(30), + tls_mode: SessionTlsMode::Disabled, + // SAFETY: this never fails + authenticator: Arc::new(MessageAuthenticator::new_cloudflare().unwrap()), + recipient_resolver: Arc::new(DummyRecipientResolver), + delivery_agent: Arc::new(DummyDeliveryAgent), + } + } +} + +/// SMTP session state +#[derive(Display)] +pub enum SessionState { + /// Need to send greeting + Greeting, + /// Default - command/response + Request(RequestReceiver), + /// ASCII data reception + Data(DataReceiver), + /// Binary data reception + Bdat(BdatReceiver), + /// Too long request received - blackhole + RequestTooLarge(DummyLineReceiver), + /// Too large data received - blackhole + DataTooLarge(DummyDataReceiver), + /// Dummy + None, +} + +impl Default for SessionState { + fn default() -> Self { + Self::Request(RequestReceiver::default()) + } +} + +/// SMTP dynamic session data +#[derive(Debug, Default)] +pub struct SessionData { + pub ehlo_hostname: Option, + pub mail_from: Option, + pub rcpt_to: Vec, + pub message: Vec, +} + +/// SMTP session counters +#[derive(Debug)] +pub struct SessionCounters { + valid_until: Instant, + bytes_ingested: usize, + messages_queued: usize, + errors: usize, +} + +impl SessionCounters { + fn new(ttl: Duration) -> Self { + Self { + valid_until: Instant::now() + ttl, + bytes_ingested: 0, + messages_queued: 0, + errors: 0, + } + } +} + +/// SMTP Session +pub struct Session { + id: Uuid, + remote_ip: IpAddr, + stream: S, + state: SessionState, + data: SessionData, + counters: SessionCounters, + cfg: Arc, + tls_info: Option, +} + +impl Display for Session { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SMTP/Session({}){}", + self.remote_ip, + if self.tls_info.is_some() { "/TLS" } else { "" } + ) + } +} + +impl Session { + pub fn new(remote_ip: IpAddr, stream: S, cfg: Arc) -> Self { + Self { + id: Uuid::now_v7(), + remote_ip, + stream, + state: SessionState::Greeting, + data: SessionData::default(), + counters: SessionCounters::new(cfg.max_session_duration), + cfg, + tls_info: None, + } + } +} + +#[cfg(test)] +mod tests { + use std::{net::SocketAddr, str::FromStr}; + + use async_trait::async_trait; + use fqdn::fqdn; + use rustls::{ClientConfig, pki_types::ServerName}; + use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + use tokio_rustls::TlsConnector; + use tokio_util::sync::CancellationToken; + + use crate::{ + smtp::{ + DeliveryError, Message, RecipientPolicy, RecipientResolveError, + inbound::manager::SessionManager, + }, + tests::{TEST_CERT_1, TEST_KEY_1}, + tls::{resolver::StubResolver, verify::NoopServerCertVerifier}, + }; + + use super::*; + + #[derive(Debug)] + pub struct TestDeliveryAgent(Option, Option); + + #[async_trait] + impl DeliversMail for TestDeliveryAgent { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError> { + if let Some(e) = &self.1 { + return Err(e.clone()); + } + + if let Some(v) = &self.0 { + assert_eq!(v, &message); + } + + Ok(()) + } + } + + #[derive(Debug)] + pub struct TestRecipientResolver( + EmailAddress, + Option, + Option>, + ); + + #[async_trait] + impl ResolvesRecipient for TestRecipientResolver { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result { + assert_eq!(rcpt, &self.0); + if let Some(v) = &self.1 { + return Ok(RecipientPolicy::Rewrite(v.clone())); + } + + if let Some(v) = &self.2 { + return Ok(RecipientPolicy::Expand(v.clone())); + } + + Ok(RecipientPolicy::Accept) + } + } + + fn create_session(stream: S, helo_delay: Option) -> Session { + let mut cfg = SessionConfig::new("test"); + cfg.max_errors = 5; + cfg.max_message_size = 512; + cfg.helo_delay = helo_delay; + cfg.max_messages_per_session = 3; + cfg.max_session_data = 8192; + cfg.max_recipients = 3; + + Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)) + } + + fn create_basic_stream() -> tokio_test::io::Builder { + let mut builder = tokio_test::io::Builder::new(); + + builder.write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"HELO foo.bar\r\n") + .write(b"250 test you had me at HELO\r\n") + .read(b"EHLO foo.bar\r\n") + .write(b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + + builder + } + + fn stream_send_message(b: &mut tokio_test::io::Builder) { + b.read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"354 Start mail input; end with .\r\n") + .read(b"foobarmessage\r\n.\r\n") + .write(b"250 2.0.0 Message (13 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n"); + } + + #[tokio::test] + async fn test_ehlo_required() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"503 5.5.1 Polite people say EHLO first.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_basic_session() { + let mut builder = create_basic_stream(); + builder + .read(b"RCPT TO:\r\n") + .write(b"503 5.5.1 MAIL FROM is required first.\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"503 5.5.1 Multiple MAIL FROM commands are not allowed.\r\n") + .read(b"DATA\r\n") + .write(b"503 5.5.1 RCPT TO is required first.\r\n") + .read(b"RSET\r\n") + .write(b"250 2.0.0 OK\r\n") + .read(b"NOOP\r\n") + .write(b"250 2.0.0 OK\r\n") + .read(b"FOOB\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"HELP\r\n") + .write(b"502 5.5.1 Command not implemented.\r\n"); + + stream_send_message(&mut builder); + let stream = builder + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_client_sends_before_greeting() { + let stream = tokio_test::io::Builder::new() + .read(b"EHLO foo.bar\r\n") + .write(b"501 5.7.1 Client sent command before greeting banner.\r\n") + .build(); + + let mut session = create_session(stream, Some(Duration::from_millis(100))); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::SendsBeforeGreeting + )); + } + + #[tokio::test] + async fn test_bdat() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"BDAT 10\r\n") + .read(b"01234") + .read(b"56789") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10\r\n") + .read(b"987") + .read(b"654") + .read(b"3210") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10 LAST\r\n") + .read(b"0123456789") + .write(b"250 2.0.0 Message (30 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: "foo@bar".try_into().unwrap(), + rcpt_to: vec!["bar@baz".try_into().unwrap()], + body: b"012345678998765432100123456789".to_vec(), + }), + None, + ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + Some("bar@baz".try_into().unwrap()), + None, + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_data() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + + let stream = builder + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: EmailAddress::from_str("foo@bar").unwrap(), + rcpt_to: vec![EmailAddress::from_str("bar@baz").unwrap()], + body: b"foobarmessage".to_vec(), + }), + None, + ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + Some("bar@baz".try_into().unwrap()), + None, + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_expand() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + + let stream = builder + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: EmailAddress::from_str("foo@bar").unwrap(), + rcpt_to: vec![ + EmailAddress::from_str("dead@beef").unwrap(), + EmailAddress::from_str("dead@dead").unwrap(), + EmailAddress::from_str("bar@bax").unwrap(), + ], + body: b"foobarmessage".to_vec(), + }), + None, + ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + None, + Some(vec![ + "dead@dead".try_into().unwrap(), + "bar@bax".try_into().unwrap(), + ]), + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_recipients() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"455 4.5.3 Too many recipients.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_message_size() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"354 Start mail input; end with .\r\n") + .read(format!("{}\r\n.\r\n", "1".repeat(513)).as_bytes()) + .write(b"552 5.3.4 Message too big for, we accept up to 512 bytes.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_messages_per_session() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + stream_send_message(&mut builder); + stream_send_message(&mut builder); + + let stream = builder + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"452 4.4.5 Maximum number of messages per session exceeded.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TooManyMessagesPerSession + )); + } + + #[tokio::test] + async fn test_max_errors() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"452 4.3.2 Too many errors.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TooManyErrors + )); + } + + #[tokio::test] + async fn test_request_too_large() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(format!("EHLO {}", "1".repeat(2048)).as_bytes()) + .read(format!("{}\r\n", "1".repeat(2048)).as_bytes()) + .write(b"554 5.3.4 Line is too long.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_session_transfer_quota() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(format!("EHLO {}\r\n", "1".repeat(8192)).as_bytes()) + .write(b"452 4.7.28 Session transfer quota exceeded.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TransferQuotaExceeded(_) + )); + } + + #[tokio::test] + async fn test_starttls() { + rustls::crypto::ring::default_provider() + .install_default() + .ok(); + + // Use an in-memory pipe + let (stream1, mut stream2) = duplex(128); + + let rustls_server_cfg = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new( + StubResolver::new(TEST_CERT_1.as_bytes(), TEST_KEY_1.as_bytes()).unwrap(), + )); + + let mut cfg = SessionConfig::new("test"); + cfg.tls_mode = SessionTlsMode::Required(Arc::new(rustls_server_cfg)); + + tokio::spawn(async move { + SessionManager::handle_connection( + stream1, + SocketAddr::from_str("1.1.1.1:123").unwrap(), + Arc::new(cfg), + CancellationToken::new(), + ) + .await; + }); + + let mut buf = vec![0; 256]; + + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"220 test ESMTP IC SMTP Gateway\r\n"); + + // Make sure there's a 250-STARTTLS in EHLO + stream2.write_all(b"EHLO foo.bar\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-STARTTLS\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + + // Make sure TLS is required by the server due to SessionTlsMode::Required + stream2.write_all(b"MAIL FROM:\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!( + &buf[..r], + b"503 5.5.1 TLS is required to submit mail on this server.\r\n" + ); + + // Fire up TLS handshake + stream2.write_all(b"STARTTLS\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"220 2.0.0 Ready to start TLS.\r\n"); + + let rustls_client_cfg = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoopServerCertVerifier::default())) + .with_no_client_auth(); + let tls_connector = TlsConnector::from(Arc::new(rustls_client_cfg)); + let mut tls_stream = tls_connector + .connect(ServerName::try_from("foo").unwrap(), stream2) + .await + .unwrap(); + + // Make sure there's no 250-STARTTLS in EHLO anymore inside TLS session + tls_stream.write_all(b"EHLO foo.bar\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + + // No TLS-in-TLS allowed + tls_stream.write_all(b"STARTTLS\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"504 5.7.4 Already in TLS mode.\r\n"); + + // Now MAIL FROM should work + tls_stream.write_all(b"MAIL FROM:\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250 2.1.0 OK\r\n"); + + // All good + tls_stream.write_all(b"QUIT\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"221 2.0.0 Bye.\r\n"); + } +} diff --git a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs new file mode 100644 index 0000000..731a4d1 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs @@ -0,0 +1,84 @@ +use std::{borrow::Cow, str::FromStr}; + +use smtp_proto::{ + RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, RcptTo, +}; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + RecipientPolicy, RecipientResolveError, + address::EmailAddress, + inbound::{Session, SessionResult}, + }, +}; + +impl Session { + /// Handles RCPT TO command + pub async fn handle_rcpt_to(&mut self, to: RcptTo>) -> SessionResult<()> { + if self.data.mail_from.is_none() { + return self + .reply("503", "5.5.1", "MAIL FROM is required first.") + .await; + } + + // Check if DSN-related stuff was requested + if (to.flags + & (RCPT_NOTIFY_DELAY | RCPT_NOTIFY_NEVER | RCPT_NOTIFY_SUCCESS | RCPT_NOTIFY_FAILURE)) + != 0 + || to.orcpt.is_some() + { + return self.ext_unsupported("DSN").await; + } + + let Ok(address) = EmailAddress::from_str(&to.address) else { + return self.reply("550", "5.1.2", "Incorrect address.").await; + }; + + if self.data.rcpt_to.contains(&address) { + return self.reply("250", "2.1.5", "OK").await; + } + + if self.data.rcpt_to.len() >= self.cfg.max_recipients { + return self.reply("455", "4.5.3", "Too many recipients.").await; + } + + match self + .cfg + .recipient_resolver + .resolve_recipient(&address) + .await + { + Ok(v) => match v { + RecipientPolicy::Accept => { + self.data.rcpt_to.push(address); + } + RecipientPolicy::Rewrite(new_address) => { + self.data.rcpt_to.push(new_address); + } + RecipientPolicy::Expand(additional_addresses) => { + self.data.rcpt_to.push(address); + self.data.rcpt_to.extend(additional_addresses); + } + }, + + Err(e) => match e { + RecipientResolveError::UnknownDomain => { + return self + .reply("550", "5.1.2", "Unknown recipient domain.") + .await; + } + RecipientResolveError::UnknownRecipient => { + return self.reply("550", "5.1.2", "Mailbox does not exist.").await; + } + RecipientResolveError::Other(_) => { + return self + .reply("451", "4.4.3", "Unable to verify address at this time.") + .await; + } + }, + } + + self.reply("250", "2.1.5", "OK").await + } +} diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs new file mode 100644 index 0000000..55ae28f --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -0,0 +1,459 @@ +use std::{ + borrow::Cow, + io::Write, + time::{Duration, Instant}, +}; + +use anyhow::Context; +use smtp_proto::{ + Error as SmtpError, Request, + request::receiver::{BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, +}; +use tokio_util::{sync::CancellationToken, time::FutureExt}; +use tracing::debug; +use uuid::Uuid; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + DeliveryError, Message, + inbound::{Session, SessionError, SessionResult, SessionState, SessionUpgrade}, + }, +}; + +const MAX_REPLY_LEN: usize = 256; + +#[allow(clippy::too_many_arguments)] +impl Session { + /// Writes given bytes to the session & flushes the buffer + pub async fn write(&mut self, bytes: &[u8]) -> SessionResult<()> { + debug!("{self}: Writing: {}", String::from_utf8_lossy(bytes)); + self.stream.write_all(bytes).await?; + self.stream.flush().await?; + Ok(()) + } + + /// Replies with given codes & message. + /// + /// It accepts replies up to `MAX_REPLY_LEN` long since it uses + /// a stack array to avoid heap allocation for performance reasons. + /// If ever this module would need more - increase the constant. + pub(crate) async fn reply(&mut self, code: &str, ext: &str, msg: &str) -> SessionResult<()> { + let len = code.len() + ext.len() + msg.len() + 4; + assert!( + len <= MAX_REPLY_LEN, + "Reply longer than supported - increase MAX_REPLY_LEN" + ); + + let mut buf = [0; MAX_REPLY_LEN]; + write!(&mut buf[..], "{code} {ext} {msg}\r\n")?; + self.write(&buf[..len]).await + } + + pub(crate) async fn ext_unsupported(&mut self, ext: &str) -> SessionResult<()> { + self.reply( + "501", + "5.5.4", + &format!("{ext} extension is not supported."), + ) + .await + } + + pub(crate) async fn message_too_big(&mut self) -> SessionResult<()> { + let msg = format!( + "Message too big for, we accept up to {} bytes.", + self.cfg.max_message_size + ); + return self.reply("552", "5.3.4", &msg).await; + } + + /// Sends greeting message + async fn greeting(&mut self) -> SessionResult<()> { + // If we have HELO delay configured - try to read from the stream for up to this duration. + // The client needs to wait silently until we send our greeting. + // If something comes in - then the client isn't respecting the protocol, + // we consider him malicious and drop the connection. + if let Some(v) = self.cfg.helo_delay { + let mut buf = [0; 256]; + match self.stream.read(&mut buf).timeout(v).await { + Ok(Ok(bytes_read)) => { + if bytes_read > 0 { + self.reply( + "501", + "5.7.1", + "Client sent command before greeting banner.", + ) + .await?; + return Err(SessionError::SendsBeforeGreeting); + } + } + Ok(Err(e)) => return Err(e.into()), + Err(_) => {} + } + } + + self.write(&self.cfg.greeting.clone()).await + } + + async fn handle_error(&mut self, error: SmtpError) -> SessionResult<()> { + let (code, ext, msg) = match error { + SmtpError::UnknownCommand | SmtpError::InvalidResponse { .. } => { + ("500", "5.5.1", "Invalid command.".to_string()) + } + SmtpError::InvalidSenderAddress => { + ("501", "5.1.8", "Bad sender's system address.".to_string()) + } + SmtpError::InvalidRecipientAddress => ( + "501", + "5.1.3", + "Bad destination mailbox address syntax.".to_string(), + ), + SmtpError::SyntaxError { syntax } => { + ("501", "5.5.2", format!("Syntax error, expected: {syntax}")) + } + SmtpError::InvalidParameter { param } => { + ("501", "5.5.4", format!("Invalid parameter {param:?}.")) + } + SmtpError::UnsupportedParameter { param } => { + ("504", "5.5.4", format!("Unsupported parameter {param:?}.")) + } + // These are handled one level above + SmtpError::ResponseTooLong | SmtpError::NeedsMoreData { .. } => unreachable!(), + }; + + self.counters.errors += 1; + self.reply(code, ext, &msg).await + } + + async fn handle_request(&mut self, req: Request>) -> SessionResult<()> { + match req { + Request::Ehlo { host } => { + self.handle_ehlo(&host, true).await?; + } + Request::Helo { host } => { + self.handle_ehlo(&host, false).await?; + } + Request::Mail { from } => { + self.handle_mail_from(from).await?; + } + Request::Rcpt { to } => { + self.handle_rcpt_to(to).await?; + } + Request::Rset => { + self.reset_message(); + self.reply("250", "2.0.0", "OK").await?; + } + Request::Quit => { + self.reply("221", "2.0.0", "Bye.").await?; + return Err(SessionError::Quit); + } + Request::Noop { .. } => { + self.reply("250", "2.0.0", "OK").await?; + } + _ => { + self.reply("502", "5.5.1", "Command not implemented.") + .await?; + self.counters.errors += 1; + } + } + + Ok(()) + } + + /// Main SMTP state machine + async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { + debug!("{self}: Read: {}", String::from_utf8_lossy(bytes)); + + // Check if we are over session transfer quota + if self.counters.bytes_ingested + bytes.len() > self.cfg.max_session_data { + self.reply("452", "4.7.28", "Session transfer quota exceeded.") + .await?; + return Err(SessionError::TransferQuotaExceeded( + self.cfg.max_session_data, + )); + } + + // Check if we are over session time quota + if Instant::now() > self.counters.valid_until { + self.reply("452", "4.3.2", "Session open for too long.") + .await?; + return Err(SessionError::TtlExceeded( + self.cfg.max_session_duration.as_secs(), + )); + } + + // Check if we are over error limit + if self.counters.errors > self.cfg.max_errors { + self.reply("452", "4.3.2", "Too many errors.").await?; + return Err(SessionError::TooManyErrors); + } + + self.counters.bytes_ingested += bytes.len(); + let mut iter = bytes.iter(); + // We can't take mutable ref to self.state & self at the same time, + // so we extract state temporarily. + let mut state = std::mem::replace(&mut self.state, SessionState::None); + + loop { + match &mut state { + SessionState::Greeting => { + // This is handled separately + unreachable!(); + } + SessionState::Request(rx) => { + match rx.ingest(&mut iter) { + Ok(request) => match request { + // ASCII data + Request::Data => { + if self.can_accept_message().await? { + self.write(b"354 Start mail input; end with .\r\n") + .await?; + self.data.message = Vec::with_capacity(1024); + state = SessionState::Data(DataReceiver::new()); + continue; + } + } + // Binary data + Request::Bdat { + chunk_size, + is_last, + } => { + // Check if we will be past max message limit with this chunk + state = if self.data.message.len() + chunk_size + > self.cfg.max_message_size + { + SessionState::DataTooLarge(DummyDataReceiver::new_bdat( + chunk_size, + )) + } else { + // Preallocate the needed capacity for the chunk if need be + let free = + self.data.message.capacity() - self.data.message.len(); + if free < chunk_size { + self.data.message.reserve(chunk_size - free); + } + + SessionState::Bdat(BdatReceiver::new(chunk_size, is_last)) + } + } + Request::StartTls => { + if self.tls_info.is_some() { + self.reply("504", "5.7.4", "Already in TLS mode.").await?; + self.counters.errors += 1; + } else if !self.cfg.tls_mode.enabled() { + self.reply("502", "5.7.0", "TLS not available.").await?; + self.counters.errors += 1; + } else { + self.reply("220", "2.0.0", "Ready to start TLS.").await?; + self.state = state; + return Ok(SessionUpgrade::StartTls); + } + } + other_request => { + self.handle_request(other_request).await?; + } + }, + Err(SmtpError::ResponseTooLong) => { + state = SessionState::RequestTooLarge(DummyLineReceiver::default()); + continue; + } + // In case of NeedsMoreData error we just leave + // and wait for new data to be ingested + Err(SmtpError::NeedsMoreData { .. }) => break, + + // Handle other errors separately + Err(e) => { + self.handle_error(e).await?; + } + } + } + SessionState::Data(rx) => { + // Check if the message already exceeds allowed size + if self.data.message.len() + bytes.len() > self.cfg.max_message_size { + state = SessionState::DataTooLarge(DummyDataReceiver::new_data(rx)); + continue; + } else if rx.ingest(&mut iter, &mut self.data.message) { + // The message is fully received, time to queue + self.queue_message().await?; + state = SessionState::default(); + } else { + // No end-of-message marker found yet + break; + } + } + SessionState::Bdat(rx) => { + if rx.ingest(&mut iter, &mut self.data.message) { + if self.can_accept_message().await? { + if rx.is_last { + self.queue_message().await?; + } else { + self.reply("250", "2.6.0", "Chunk accepted.").await?; + } + } else { + self.data.message = Vec::with_capacity(0); + } + state = SessionState::default(); + } else { + // Still some bytes left in the chunk + break; + } + } + SessionState::RequestTooLarge(rx) => { + // If line-feed found - issue error, otherwise keep ingesting + if rx.ingest(&mut iter) { + self.reply("554", "5.3.4", "Line is too long.").await?; + state = SessionState::default(); + self.counters.errors += 1; + } else { + // No line-feed found yet + break; + } + } + SessionState::DataTooLarge(rx) => { + // If end-of-message marker found - issue error, otherwise keep ingesting + if rx.ingest(&mut iter) { + self.message_too_big().await?; + state = SessionState::default(); + self.counters.errors += 1; + } else { + // No end-of-message marker found yet + break; + } + } + SessionState::None => unreachable!(), + } + } + self.state = state; + + Ok(SessionUpgrade::No) + } + + /// Drives the session forward + pub async fn handle( + &mut self, + shutdown_token: CancellationToken, + ) -> SessionResult { + let mut buf = vec![0; 8192]; + + if matches!(self.state, SessionState::Greeting) { + self.greeting().await?; + self.state = SessionState::default(); + } + + loop { + select! { + // Read from the client with a timeout + res = self.stream.read(&mut buf).timeout(self.cfg.timeout) => { + match res { + Ok(Ok(bytes_read)) => { + let upgrade = self.ingest(&buf[..bytes_read]).await?; + if matches!(upgrade, SessionUpgrade::StartTls) { + return Ok(upgrade); + } + } + Ok(Err(e)) => { + return Err(e.into()); + } + Err(_) => { + self.reply("221", "2.0.0", "Disconnecting due to inactivity.").await?; + return Err(SessionError::Timeout); + } + } + }, + + () = shutdown_token.cancelled() => { + break; + } + } + } + + Ok(SessionUpgrade::No) + } + + async fn queue_message(&mut self) -> SessionResult<()> { + #[cfg(not(test))] + let id = Uuid::now_v7(); + #[cfg(test)] + let id = Uuid::nil(); + + let message_size = self.data.message.len(); + + // SAFETY: Code makes sure these are all Some(). + // It's better to panic in tests if they are not. + let message = Message { + id, + ehlo_hostname: self.data.ehlo_hostname.clone().unwrap(), + mail_from: self.data.mail_from.take().unwrap(), + rcpt_to: std::mem::take(&mut self.data.rcpt_to), + body: std::mem::take(&mut self.data.message), + }; + + if let Err(e) = self.cfg.delivery_agent.deliver_mail(message).await { + let (code, ext, msg) = match e { + DeliveryError::Permanent(v) => { + ("550", "5.5.0", format!("Permanent delivery error: {v}")) + } + DeliveryError::Temporary(v) => { + ("450", "4.5.0", format!("Temporary delivery error: {v}")) + } + }; + + self.reply(code, ext, &msg).await?; + self.reset_message(); + return Ok(()); + } + + self.reply( + "250", + "2.0.0", + &format!("Message ({message_size} bytes) queued with id {id}"), + ) + .await?; + + self.counters.messages_queued += 1; + self.reset_message(); + Ok(()) + } + + async fn can_accept_message(&mut self) -> SessionResult { + if self.counters.messages_queued >= self.cfg.max_messages_per_session { + self.reply( + "452", + "4.4.5", + "Maximum number of messages per session exceeded.", + ) + .await?; + return Err(SessionError::TooManyMessagesPerSession); + } else if self.data.rcpt_to.is_empty() { + self.reply("503", "5.5.1", "RCPT TO is required first.") + .await?; + self.counters.errors += 1; + return Ok(false); + } + + Ok(true) + } + + /// Resets the message-related fields to their initial state + pub(crate) fn reset_message(&mut self) { + self.data.mail_from = None; + self.data.rcpt_to.clear(); + self.data.message.clear(); + } + + /// Closes the connection + pub async fn shutdown(&mut self) -> SessionResult<()> { + self.stream + .shutdown() + .timeout(Duration::from_secs(10)) + .await + .context("shutdown timed out")? + .context("shutdown failed")?; + + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs new file mode 100644 index 0000000..24756a8 --- /dev/null +++ b/ic-bn-lib/src/smtp/mod.rs @@ -0,0 +1,107 @@ +use std::fmt::{Debug, Display}; + +use async_trait::async_trait; +use fqdn::FQDN; +use itertools::Itertools; +use tracing::warn; +use uuid::Uuid; + +use crate::smtp::address::EmailAddress; + +pub mod address; +pub mod ic; +pub mod inbound; +pub mod server; + +/// Recipient resolution policy +pub enum RecipientPolicy { + Accept, + Rewrite(EmailAddress), + Expand(Vec), +} + +/// Recipient resolution error +#[derive(thiserror::Error, Debug)] +pub enum RecipientResolveError { + #[error("Unknown recipient")] + UnknownRecipient, + #[error("Unknown domain")] + UnknownDomain, + #[error("{0}")] + Other(String), +} + +/// Delivery error +#[derive(thiserror::Error, Clone, Debug)] +pub enum DeliveryError { + #[error("{0}")] + Temporary(String), + #[error("{0}")] + Permanent(String), +} + +/// Low-level E-Mail representation +#[derive(Debug, Eq, PartialEq, Hash)] +pub struct Message { + pub id: Uuid, + pub ehlo_hostname: FQDN, + pub mail_from: EmailAddress, + pub rcpt_to: Vec, + pub body: Vec, +} + +impl Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "id: {}, ehlo: {}, from: {}, to: {}, msg: {}", + self.id, + self.ehlo_hostname, + self.mail_from, + self.rcpt_to.iter().map(|x| x.to_string()).join(", "), + String::from_utf8_lossy(&self.body) + .replace('\n', "\\n") + .replace('\r', "\\r") + ) + } +} + +/// Looks up the given recipient & applies `RecipientPolicy` policy +#[async_trait] +pub trait ResolvesRecipient: Send + Sync + Debug { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result; +} + +/// Delivers the E-Mail message +#[async_trait] +pub trait DeliversMail: Send + Sync + Debug { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError>; +} + +#[derive(Debug)] +pub struct DummyRecipientResolver; + +#[async_trait] +impl ResolvesRecipient for DummyRecipientResolver { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result { + warn!("DummyRecipientResolver: {rcpt}"); + Ok(RecipientPolicy::Accept) + } +} + +#[derive(Debug)] +pub struct DummyDeliveryAgent; + +#[async_trait] +impl DeliversMail for DummyDeliveryAgent { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError> { + warn!("DummyDeliveryAgent: {message}"); + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/server.rs b/ic-bn-lib/src/smtp/server.rs new file mode 100644 index 0000000..0e6d627 --- /dev/null +++ b/ic-bn-lib/src/smtp/server.rs @@ -0,0 +1,101 @@ +use std::{fmt::Display, io, net::SocketAddr, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use ic_bn_lib_common::{traits::Run, types::http::ListenerOpts}; +use tokio::{ + net::{TcpListener, TcpStream}, + select, +}; +use tokio_util::{sync::CancellationToken, task::TaskTracker, time::FutureExt}; +use tracing::{info, warn}; + +use crate::{ + network::listener::listen_tcp, + smtp::inbound::{SessionConfig, SessionError, SessionResult, manager::SessionManager}, +}; + +/// Listens for new SMTP connections and creates sessions +pub struct Server { + listen_addr: SocketAddr, + listener: TcpListener, + params: Arc, + tracker: TaskTracker, +} + +impl Display for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SMTP/Server({})", self.listen_addr) + } +} + +impl Server { + /// Creates a new `Server` to listen on `listen_addr` + pub fn new(listen_addr: SocketAddr, cfg: SessionConfig) -> io::Result { + let listener = listen_tcp(listen_addr, ListenerOpts::default())?; + Self::new_with_listener(listener, cfg) + } + + /// Creates a new `Server` from a pre-built `TcpListener` + pub fn new_with_listener(listener: TcpListener, params: SessionConfig) -> io::Result { + Ok(Self { + listen_addr: listener.local_addr()?, + listener, + params: Arc::new(params), + tracker: TaskTracker::new(), + }) + } + + async fn handle_connection( + &self, + res: io::Result<(TcpStream, SocketAddr)>, + token: &CancellationToken, + ) { + match res { + Ok((stream, addr)) => { + info!("{self}: New connection from {addr}"); + + let (params, token) = (self.params.clone(), token.child_token()); + self.tracker.spawn(SessionManager::handle_connection( + stream, addr, params, token, + )); + } + + Err(e) => { + warn!("{self}: Unable to accept connection: {e:#}"); + tokio::time::sleep(Duration::from_millis(50)).await; + } + } + } + + /// Main connection handling loop + pub async fn serve(&self, token: CancellationToken) -> io::Result<()> { + loop { + select! { + res = self.listener.accept() => { + self.handle_connection(res, &token).await; + } + + () = token.cancelled() => { + warn!("{self}: Server shutting down, closing connections"); + + self.tracker.close(); + if self.tracker.wait().timeout(Duration::from_secs(30)).await.is_err() { + warn!("{self}: Timed out waiting for connections to close"); + } + + break; + } + } + } + + Ok(()) + } +} + +#[async_trait] +impl Run for Server { + async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> { + self.serve(token).await?; + Ok(()) + } +} diff --git a/ic-bn-lib/tools/smtp_server.rs b/ic-bn-lib/tools/smtp_server.rs new file mode 100644 index 0000000..c39de8f --- /dev/null +++ b/ic-bn-lib/tools/smtp_server.rs @@ -0,0 +1,34 @@ +use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration}; + +use ic_bn_lib::{ + smtp::{inbound::SessionConfig, server::Server}, + tests::{TEST_CERT_1, TEST_KEY_2}, + tls::resolver::StubResolver, +}; +use rustls::ServerConfig; +use tokio_util::sync::CancellationToken; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().init(); + rustls::crypto::ring::default_provider() + .install_default() + .unwrap(); + + let rustls_server_cfg = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new( + StubResolver::new(TEST_CERT_1.as_bytes(), TEST_KEY_2.as_bytes()).unwrap(), + )); + + let mut cfg = SessionConfig::new("mail.icp.net"); + //cfg.helo_delay = Some(Duration::from_secs(1)); + //params.max_message_size = 16; + //params.max_session_duration = Duration::from_secs(30); + //params.max_session_data = 16; + cfg.tls_mode = ic_bn_lib::smtp::inbound::SessionTlsMode::Allowed(Arc::new(rustls_server_cfg)); + + let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), cfg).unwrap(); + + server.serve(CancellationToken::new()).await.unwrap(); +}