From cac222595fe761d5fd013cb452c2e7699aea01cf Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Wed, 13 May 2026 15:55:36 +0200 Subject: [PATCH 1/8] Initial SMTP work --- Cargo.toml | 3 + ic-bn-lib/Cargo.toml | 3 + ic-bn-lib/src/http/mod.rs | 79 +------- ic-bn-lib/src/http/server/mod.rs | 149 +-------------- ic-bn-lib/src/lib.rs | 2 + ic-bn-lib/src/network/listener.rs | 100 ++++++++++ ic-bn-lib/src/network/mod.rs | 100 ++++++++++ ic-bn-lib/src/smtp/address.rs | 71 +++++++ ic-bn-lib/src/smtp/ic/candid.rs | 57 ++++++ ic-bn-lib/src/smtp/ic/mod.rs | 1 + ic-bn-lib/src/smtp/ic/smtp.rs | 294 +++++++++++++++++++++++++++++ ic-bn-lib/src/smtp/mod.rs | 3 + ic-bn-lib/src/smtp/session/helo.rs | 33 ++++ ic-bn-lib/src/smtp/session/mod.rs | 98 ++++++++++ 14 files changed, 775 insertions(+), 218 deletions(-) create mode 100644 ic-bn-lib/src/network/listener.rs create mode 100644 ic-bn-lib/src/network/mod.rs create mode 100644 ic-bn-lib/src/smtp/address.rs create mode 100644 ic-bn-lib/src/smtp/ic/candid.rs create mode 100644 ic-bn-lib/src/smtp/ic/mod.rs create mode 100644 ic-bn-lib/src/smtp/ic/smtp.rs create mode 100644 ic-bn-lib/src/smtp/mod.rs create mode 100644 ic-bn-lib/src/smtp/session/helo.rs create mode 100644 ic-bn-lib/src/smtp/session/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 3083046..67d3fb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,8 @@ instant-acme = { version = "0.8.2", default-features = false, features = [ "ring", "hyper-rustls", ] } +mail-auth = "0.8.0" +mail-parser = { version = "0.11.3", features = ["full_encoding"] } mockall = "0.13.0" parse-size = { version = "1.1.0", features = ["std"] } prometheus = "0.14.0" @@ -64,6 +66,7 @@ rustls = { version = "0.23.18", default-features = false, features = [ "brotli", ] } serde = { version = "1.0.214", features = ["derive"] } +smtp-proto = "0.2.1" socket2 = "0.6.0" strum = { version = "0.28", features = ["derive"] } strum_macros = "0.28" diff --git a/ic-bn-lib/Cargo.toml b/ic-bn-lib/Cargo.toml index a69c2f4..30eca4d 100644 --- a/ic-bn-lib/Cargo.toml +++ b/ic-bn-lib/Cargo.toml @@ -74,6 +74,8 @@ ic-bn-lib-common = "0.1" indoc = "2.0.6" instant-acme = { workspace = true, optional = true } itertools = "0.14.0" +mail-auth = { workspace = true } +mail-parser = { workspace = true } moka = { version = "0.12.15", features = ["sync", "future"] } nix = { version = "0.30.0", features = ["signal"] } ppp = "2.3.0" @@ -107,6 +109,7 @@ sev = { version = "7.1.0", optional = true, default-features = false, features = ] } sha1 = "0.11.0" sha2 = { version = "0.11.0", optional = true } +smtp-proto = { workspace = true } socket2 = "0.6.0" strum = { workspace = true } strum_macros = { workspace = true } diff --git a/ic-bn-lib/src/http/mod.rs b/ic-bn-lib/src/http/mod.rs index 84b8548..476a357 100644 --- a/ic-bn-lib/src/http/mod.rs +++ b/ic-bn-lib/src/http/mod.rs @@ -8,17 +8,8 @@ pub mod proxy; pub mod server; pub mod shed; -use std::{ - io, - pin::{Pin, pin}, - sync::{Arc, atomic::Ordering}, - task::{Context, Poll}, -}; - use axum::response::{IntoResponse, Redirect}; use http::{HeaderMap, Method, Request, StatusCode, Uri, Version, header::HOST, uri::PathAndQuery}; -use ic_bn_lib_common::types::http::Stats; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "clients-hyper")] pub use client::clients_hyper::{HyperClient, HyperClientLeastLoaded}; @@ -28,11 +19,7 @@ pub use client::clients_reqwest::{ pub use server::{Server, ServerBuilder}; use url::Url; -use crate::http::headers::X_FORWARDED_HOST; - -/// Blanket async read+write trait for streams Box-ing -trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} -impl AsyncReadWrite for T {} +use crate::{http::headers::X_FORWARDED_HOST, network::AsyncReadWrite}; /// Calculate very approximate HTTP request/response headers size in bytes. /// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed. @@ -101,70 +88,6 @@ pub fn extract_authority(request: &Request) -> Option<&str> { .and_then(extract_host) } -/// Async read+write wrapper that counts bytes read/written -struct AsyncCounter { - inner: T, - stats: Arc, -} - -impl AsyncCounter { - /// Create new `AsyncCounter` - pub fn new(inner: T) -> (Self, Arc) { - let stats = Arc::new(Stats::new()); - - ( - Self { - inner, - stats: stats.clone(), - }, - stats, - ) - } -} - -impl AsyncRead for AsyncCounter { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let size_before = buf.filled().len(); - let poll = pin!(&mut self.inner).poll_read(cx, buf); - if matches!(&poll, Poll::Ready(Ok(()))) { - let rcvd = buf.filled().len() - size_before; - self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst); - } - - poll - } -} - -impl AsyncWrite for AsyncCounter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let poll = pin!(&mut self.inner).poll_write(cx, buf); - if let Poll::Ready(Ok(v)) = &poll { - self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst); - } - - poll - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - pin!(&mut self.inner).poll_shutdown(cx) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - pin!(&mut self.inner).poll_flush(cx) - } -} - /// Error that might happen during Url to Uri conversion #[derive(thiserror::Error, Debug)] pub enum UrlToUriError { diff --git a/ic-bn-lib/src/http/server/mod.rs b/ic-bn-lib/src/http/server/mod.rs index ef8056f..7e89d9d 100644 --- a/ic-bn-lib/src/http/server/mod.rs +++ b/ic-bn-lib/src/http/server/mod.rs @@ -2,9 +2,7 @@ pub mod proxy_protocol; use std::{ fmt::Display, - io, net::SocketAddr, - os::unix::fs::PermissionsExt, path::PathBuf, sync::{ Arc, @@ -39,10 +37,8 @@ use prometheus::{ use proxy_protocol::{ProxyHeader, ProxyProtocolStream}; use rustls::sign::SingleCertAndKey; use scopeguard::defer; -use socket2::{Domain, Socket, Type}; use tokio::{ io::AsyncWriteExt, - net::{TcpListener, UnixListener, UnixSocket}, pin, select, sync::mpsc::channel, time::{sleep, timeout}, @@ -54,92 +50,20 @@ use tower_service::Service; use tracing::{debug, info, warn}; use uuid::Uuid; -use super::{AsyncCounter, AsyncReadWrite, body::NotifyingBody}; -use crate::tls::{pem_convert_to_rustls, prepare_server_config}; +use super::body::NotifyingBody; +use crate::{ + network::{AsyncCounter, AsyncReadWrite, listener::Listener, tls_handshake}, + tls::{pem_convert_to_rustls, prepare_server_config}, +}; const YEAR: Duration = Duration::from_secs(86400 * 365); -/// Connection listener -pub enum Listener { - Tcp(TcpListener), - Unix(UnixListener), -} - -impl Listener { - /// Create a new Listener - pub fn new(addr: Addr, opts: ListenerOpts) -> Result { - Ok(match addr { - Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?), - Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?), - }) - } - - /// Accept the connection - async fn accept(&self) -> Result<(Box, Addr), io::Error> { - Ok(match self { - Self::Tcp(v) => { - let x = v.accept().await?; - (Box::new(x.0), Addr::Tcp(x.1)) - } - Self::Unix(v) => { - let x = v.accept().await?; - ( - Box::new(x.0), - Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()), - ) - } - }) - } - - pub fn local_addr(&self) -> Option { - match &self { - Self::Tcp(v) => v.local_addr().ok(), - Self::Unix(_) => None, - } - } -} - -impl From for Listener { - /// Creates a Listener from TcpListener - fn from(v: TcpListener) -> Self { - Self::Tcp(v) - } -} - -impl From for Listener { - /// Creates a Listener from UnixListener - fn from(v: UnixListener) -> Self { - Self::Unix(v) - } -} - #[derive(Clone)] enum RequestState { Start, End, } -async fn tls_handshake( - rustls_cfg: Arc, - stream: impl AsyncReadWrite, -) -> Result<(impl AsyncReadWrite, TlsInfo), Error> { - let tls_acceptor = TlsAcceptor::from(rustls_cfg); - - // Perform the TLS handshake - let start = Instant::now(); - let stream = tls_acceptor - .accept(stream) - .await - .context("TLS accept failed")?; - let duration = start.elapsed(); - - let conn = stream.get_ref().1; - let mut tls_info = TlsInfo::try_from(conn)?; - tls_info.handshake_dur = duration; - - Ok((stream, tls_info)) -} - struct Conn { addr: Addr, remote_addr: Addr, @@ -638,7 +562,8 @@ impl Server { keepalive: (&self.options).into(), }; - let listener = Listener::new(self.addr.clone(), opts)?; + let listener = + Listener::new(self.addr.clone(), opts).context("unable to create listener")?; self.serve_with_listener(listener, token).await } @@ -733,64 +658,6 @@ impl Server { } } -/// Creates a TCP listener with given opts -pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result { - let domain = if addr.is_ipv4() { - Domain::IPV4 - } else { - Domain::IPV6 - }; - - let socket = Socket::new(domain, Type::STREAM, None).context("unable to create socket")?; - socket - .set_tcp_nodelay(true) - .context("unable to set TCP_NODELAY")?; - - if let Some(v) = opts.mss { - socket.set_tcp_mss(v).context("unable to set TCP MSS")?; - } - - socket - .set_reuse_address(true) - .context("unable to set SO_REUSEADDR")?; - socket - .set_tcp_keepalive(&opts.keepalive) - .context("unable to set keepalive on the socket")?; - socket - .set_nonblocking(true) - .context("unable to set socket into non-blocking mode")?; - - socket.bind(&addr.into()).context("unable to bind socket")?; - socket - .listen(opts.backlog as i32) - .context("unable to listen on the socket")?; - - let listener = TcpListener::from_std(socket.into()) - .context("unable to convert socket from the standard one")?; - - Ok(listener) -} - -/// Creates a Unix Socket listener with given opts -pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result { - let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?; - - if path.exists() { - std::fs::remove_file(&path).context("unable to remove UNIX socket")?; - } - - socket.bind(&path).context("unable to bind socket")?; - - let socket = socket - .listen(opts.backlog) - .context("unable to listen socket")?; - - std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666)) - .context("unable to set permissions on socket")?; - - Ok(socket) -} - #[async_trait] impl Run for Server { async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> { @@ -803,6 +670,8 @@ impl Run for Server { mod test { use http::StatusCode; + use crate::network::listener::listen_tcp; + use super::*; #[tokio::test] diff --git a/ic-bn-lib/src/lib.rs b/ic-bn-lib/src/lib.rs index 19e3493..2fd9395 100644 --- a/ic-bn-lib/src/lib.rs +++ b/ic-bn-lib/src/lib.rs @@ -8,7 +8,9 @@ #[cfg(feature = "custom-domains")] pub mod custom_domains; pub mod http; +pub mod network; pub mod pubsub; +pub mod smtp; pub mod tasks; pub mod tests; pub mod tls; diff --git a/ic-bn-lib/src/network/listener.rs b/ic-bn-lib/src/network/listener.rs new file mode 100644 index 0000000..776e980 --- /dev/null +++ b/ic-bn-lib/src/network/listener.rs @@ -0,0 +1,100 @@ +use std::{io, net::SocketAddr, os::unix::fs::PermissionsExt, path::PathBuf}; + +use ic_bn_lib_common::types::http::{Addr, ListenerOpts}; +use socket2::{Domain, Socket, Type}; +use tokio::net::{TcpListener, UnixListener, UnixSocket}; + +use crate::network::AsyncReadWrite; + +/// Generic connection listener +pub enum Listener { + Tcp(TcpListener), + Unix(UnixListener), +} + +impl Listener { + /// Create a new Listener + pub fn new(addr: Addr, opts: ListenerOpts) -> Result { + Ok(match addr { + Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?), + Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?), + }) + } + + /// Accept the connection + pub async fn accept(&self) -> Result<(Box, Addr), io::Error> { + Ok(match self { + Self::Tcp(v) => { + let x = v.accept().await?; + (Box::new(x.0), Addr::Tcp(x.1)) + } + Self::Unix(v) => { + let x = v.accept().await?; + ( + Box::new(x.0), + Addr::Unix(x.1.as_pathname().map(|x| x.into()).unwrap_or_default()), + ) + } + }) + } + + pub fn local_addr(&self) -> Option { + match &self { + Self::Tcp(v) => v.local_addr().ok(), + Self::Unix(_) => None, + } + } +} + +impl From for Listener { + /// Creates a Listener from TcpListener + fn from(v: TcpListener) -> Self { + Self::Tcp(v) + } +} + +impl From for Listener { + /// Creates a Listener from UnixListener + fn from(v: UnixListener) -> Self { + Self::Unix(v) + } +} + +/// Creates a TCP listener with given opts +pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result { + let domain = if addr.is_ipv4() { + Domain::IPV4 + } else { + Domain::IPV6 + }; + + let socket = Socket::new(domain, Type::STREAM, None)?; + socket.set_tcp_nodelay(true)?; + + if let Some(v) = opts.mss { + socket.set_tcp_mss(v)?; + } + + socket.set_reuse_address(true)?; + socket.set_tcp_keepalive(&opts.keepalive)?; + socket.set_nonblocking(true)?; + socket.bind(&addr.into())?; + socket.listen(opts.backlog as i32)?; + + TcpListener::from_std(socket.into()) +} + +/// Creates a Unix Socket listener with given opts +pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> Result { + let socket = UnixSocket::new_stream()?; + + if path.exists() { + std::fs::remove_file(&path)?; + } + + socket.bind(&path)?; + let socket = socket.listen(opts.backlog)?; + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o666))?; + + Ok(socket) +} diff --git a/ic-bn-lib/src/network/mod.rs b/ic-bn-lib/src/network/mod.rs new file mode 100644 index 0000000..7f83397 --- /dev/null +++ b/ic-bn-lib/src/network/mod.rs @@ -0,0 +1,100 @@ +use std::{ + io, + pin::{Pin, pin}, + sync::{Arc, atomic::Ordering}, + task::{Context, Poll}, + time::Instant, +}; + +use ic_bn_lib_common::types::http::{Stats, TlsInfo}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::TlsAcceptor; + +pub mod listener; + +/// Blanket async read+write trait for streams Box-ing +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} +impl AsyncReadWrite for T {} + +/// Performs TLS handshake on the given stream +pub async fn tls_handshake( + rustls_cfg: Arc, + stream: impl AsyncReadWrite, +) -> Result<(impl AsyncReadWrite, TlsInfo), io::Error> { + let tls_acceptor = TlsAcceptor::from(rustls_cfg); + + // Perform the TLS handshake + let start = Instant::now(); + let stream = tls_acceptor.accept(stream).await?; + let duration = start.elapsed(); + + let conn = stream.get_ref().1; + let mut tls_info = TlsInfo::try_from(conn).map_err(io::Error::other)?; + tls_info.handshake_dur = duration; + + Ok((stream, tls_info)) +} + +/// Async read+write wrapper that counts bytes read/written +pub struct AsyncCounter { + inner: T, + stats: Arc, +} + +impl AsyncCounter { + /// Create new `AsyncCounter` + pub fn new(inner: T) -> (Self, Arc) { + let stats = Arc::new(Stats::new()); + + ( + Self { + inner, + stats: stats.clone(), + }, + stats, + ) + } +} + +impl AsyncRead for AsyncCounter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let size_before = buf.filled().len(); + let poll = pin!(&mut self.inner).poll_read(cx, buf); + if matches!(&poll, Poll::Ready(Ok(()))) { + let rcvd = buf.filled().len() - size_before; + self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst); + } + + poll + } +} + +impl AsyncWrite for AsyncCounter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let poll = pin!(&mut self.inner).poll_write(cx, buf); + if let Poll::Ready(Ok(v)) = &poll { + self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst); + } + + poll + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + pin!(&mut self.inner).poll_shutdown(cx) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&mut self.inner).poll_flush(cx) + } +} diff --git a/ic-bn-lib/src/smtp/address.rs b/ic-bn-lib/src/smtp/address.rs new file mode 100644 index 0000000..893b3ef --- /dev/null +++ b/ic-bn-lib/src/smtp/address.rs @@ -0,0 +1,71 @@ +use std::{fmt::Display, str::FromStr}; + +use derive_new::new; +use fqdn::FQDN; + +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub enum EmailAddressError { + #[error("@ is missing")] + AtMissing, + #[error("Domain incorrect: {0}")] + DomainIncorrect(String), +} + +/// E-Mail address representation. +/// +/// Currently we don't validate the local part at all +/// and just consider everything to the right from the +/// rightmost @ as a domain part. +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, new)] +pub struct EmailAddress { + pub local: String, + pub domain: FQDN, +} + +impl Display for EmailAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}@{}", self.local, self.domain) + } +} + +impl FromStr for EmailAddress { + type Err = EmailAddressError; + + fn from_str(s: &str) -> Result { + let (local, domain) = s.rsplit_once('@').ok_or(EmailAddressError::AtMissing)?; + let domain = FQDN::from_ascii_str(domain) + .map_err(|e| EmailAddressError::DomainIncorrect(e.to_string()))?; + + Ok(Self { + local: local.into(), + domain, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_email_address() { + // ok + for v in ["foo@bar", "john.doe@jane.doe", "\"foo+bar@baz\"@dead.beef"] { + assert_eq!(EmailAddress::from_str(v).unwrap().to_string(), v); + } + + // no @ + assert_eq!( + EmailAddress::from_str("foo").unwrap_err(), + EmailAddressError::AtMissing + ); + + // bad domain + for v in ["foo@bar\"baz", "\"jane@doe\""] { + assert!(matches!( + EmailAddress::from_str(v).unwrap_err(), + EmailAddressError::DomainIncorrect(_) + )); + } + } +} diff --git a/ic-bn-lib/src/smtp/ic/candid.rs b/ic-bn-lib/src/smtp/ic/candid.rs new file mode 100644 index 0000000..8ef72c1 --- /dev/null +++ b/ic-bn-lib/src/smtp/ic/candid.rs @@ -0,0 +1,57 @@ +//! Candid types for an IC SMTP Protocol + +use candid::{CandidType, Deserialize}; + +/// Candid `Header`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Header { + pub name: String, + pub value: String, +} + +/// Candid `Message`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Message { + pub headers: Vec
, + pub body: Vec, +} + +/// Candid `Address`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Address { + pub user: String, + pub domain: String, +} + +/// Candid `Envelope`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Envelope { + pub from: Address, + pub to: Address, +} + +/// Candid `SmtpRequest`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequest { + pub message: Option, + pub envelope: Option, + pub gateway_flags: Option>, +} + +/// Candid `SmtpRequestError` (`code` is `nat64` on the wire in typical canisters). +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequestError { + pub code: u64, + pub message: String, +} + +/// Candid `SmtpResponse` — `Ok` carries an empty record. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub enum SmtpResponse { + Ok(SmtpOk), + Err(SmtpRequestError), +} + +/// Empty record for variant `Ok`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpOk {} diff --git a/ic-bn-lib/src/smtp/ic/mod.rs b/ic-bn-lib/src/smtp/ic/mod.rs new file mode 100644 index 0000000..938a4a0 --- /dev/null +++ b/ic-bn-lib/src/smtp/ic/mod.rs @@ -0,0 +1 @@ +pub mod candid; diff --git a/ic-bn-lib/src/smtp/ic/smtp.rs b/ic-bn-lib/src/smtp/ic/smtp.rs new file mode 100644 index 0000000..678006c --- /dev/null +++ b/ic-bn-lib/src/smtp/ic/smtp.rs @@ -0,0 +1,294 @@ +//! Candid types and submit logic for the SMTP gateway ↔ canister protocol. + +use candid::{CandidType, Decode, Deserialize, Encode, Principal}; +use ic_bn_lib::ic_agent::{Agent, AgentError}; +use tracing::{debug, warn}; + +use crate::smtp::ReceivedMail; + +/// Candid `Header`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Header { + pub name: String, + pub value: String, +} + +/// Candid `Message`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Message { + pub headers: Vec
, + pub body: Vec, +} + +/// Candid `Address`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Address { + pub user: String, + pub domain: String, +} + +/// Candid `Envelope`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct Envelope { + pub from: Address, + pub to: Address, +} + +/// Candid `SmtpRequest`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequest { + pub message: Option, + pub envelope: Option, + pub gateway_flags: Option>, +} + +/// Candid `SmtpRequestError` (`code` is `nat64` on the wire in typical canisters). +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpRequestError { + pub code: u64, + pub message: String, +} + +/// Candid `SmtpResponse` — `Ok` carries an empty record. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub enum SmtpResponse { + Ok(SmtpOk), + Err(SmtpRequestError), +} + +/// Empty record for variant `Ok`. +#[derive(Clone, Debug, CandidType, Deserialize)] +pub struct SmtpOk {} + +fn address_from_smtp_path(path: &str) -> Address { + let path = path.trim(); + if path.is_empty() { + return Address { + user: String::new(), + domain: String::new(), + }; + } + match path.rsplit_once('@') { + Some((user, domain)) => Address { + user: user.to_string(), + domain: domain.to_string(), + }, + None => Address { + user: path.to_string(), + domain: String::new(), + }, + } +} + +/// Split RFC 5322 message into header block and body; parse headers with line unfolding. +pub fn parse_rfc5322_message(raw: &[u8]) -> Result { + let (header_end, body_start) = raw + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map(|i| (i, i + 4)) + .or_else(|| { + raw.windows(2) + .position(|w| w == b"\n\n") + .map(|i| (i, i + 2)) + }) + .ok_or_else(|| "message has no header/body separator".to_string())?; + + let headers_src = std::str::from_utf8(&raw[..header_end]) + .map_err(|_| "message headers are not valid UTF-8".to_string())?; + let headers = parse_headers_unfolded(headers_src)?; + let body = raw[body_start..].to_vec(); + Ok(Message { headers, body }) +} + +fn parse_headers_unfolded(block: &str) -> Result, String> { + let lines = unfold_header_block(block); + let mut out = Vec::new(); + for line in lines { + let line = line.trim_end_matches(['\r', '\n']); + if line.is_empty() { + continue; + } + let Some((name, value)) = line.split_once(':') else { + return Err(format!("bad header line: {line:?}")); + }; + let name = name.trim().to_string(); + if name.is_empty() { + return Err("empty header name".to_string()); + } + let value = value.trim_start_matches([' ', '\t']).to_string(); + out.push(Header { name, value }); + } + Ok(out) +} + +/// RFC 5322 unfolding: continuation lines start with WSP. +fn unfold_header_block(block: &str) -> Vec { + let mut merged: Vec = Vec::new(); + for raw_line in block.split('\n') { + let line = raw_line.trim_end_matches('\r'); + let first = line.chars().next(); + let is_continuation = matches!(first, Some(' ' | '\t')); + if is_continuation && !merged.is_empty() { + let last = merged.last_mut().expect("merged non-empty"); + last.push(' '); + last.push_str(line.trim_start_matches([' ', '\t'])); + } else { + merged.push(line.to_string()); + } + } + merged +} + +/// Map canister SMTP-style error codes to an SMTP text reply (code + message for the client). +pub fn smtp_line_from_canister_err(e: &SmtpRequestError) -> (u16, String) { + let c = e.code; + let code = if (400..600).contains(&c) { + c as u16 + } else if c < 400 { + 451 + } else { + 554 + }; + (code, e.message.clone()) +} + +fn agent_err_to_string(e: AgentError) -> String { + e.to_string() +} + +/// Failure from [`submit_mail`]: transport/parse errors or a canister rejection for one recipient. +#[derive(Debug)] +pub enum SubmitMailError { + Other(String), + Rejected { + code: u16, + message: String, + failed_recipient: String, + }, +} + +impl SubmitMailError { + /// SMTP session reply text: `" "` for rejections (matches [`crate::smtp::session::handler_error_to_response`]). + pub fn into_handler_error(self) -> String { + match self { + SubmitMailError::Other(s) => s, + SubmitMailError::Rejected { code, message, .. } => format!("{code} {message}"), + } + } +} + +impl From for SubmitMailError { + fn from(s: String) -> Self { + SubmitMailError::Other(s) + } +} + +/// Submit mail: optional `smtp_request_validate` (query) per recipient, then `smtp_request` (update). +pub async fn submit_mail( + agent: &Agent, + canister_id: Principal, + mail: &ReceivedMail, + gateway_flags: &[String], + validate_before_update: bool, +) -> Result<(), SubmitMailError> { + if mail.rcpt_to.is_empty() { + return Err(SubmitMailError::Other( + "internal error: no recipients".to_string(), + )); + } + + let message = parse_rfc5322_message(&mail.raw_message).map_err(SubmitMailError::Other)?; + let from_addr = address_from_smtp_path(&mail.mail_from); + let flags = if gateway_flags.is_empty() { + None + } else { + Some(gateway_flags.to_vec()) + }; + + for to_path in &mail.rcpt_to { + let to_addr = address_from_smtp_path(to_path); + let envelope = Envelope { + from: from_addr.clone(), + to: to_addr, + }; + + if validate_before_update { + let validate_req = SmtpRequest { + message: None, + envelope: Some(envelope.clone()), + gateway_flags: flags.clone(), + }; + let arg = Encode!(&validate_req).map_err(|e| SubmitMailError::Other(e.to_string()))?; + let out = agent + .query(&canister_id, "smtp_request_validate") + .with_arg(arg) + .call() + .await + .map_err(|e| SubmitMailError::Other(agent_err_to_string(e)))?; + let resp = Decode!(&out, SmtpResponse).map_err(|e| SubmitMailError::Other(e.to_string()))?; + match resp { + SmtpResponse::Ok(_) => {} + SmtpResponse::Err(err) => { + let (code, msg) = smtp_line_from_canister_err(&err); + warn!(%canister_id, %to_path, canister_code = %err.code, "smtp_request_validate rejected"); + return Err(SubmitMailError::Rejected { + code, + message: msg, + failed_recipient: to_path.clone(), + }); + } + } + } + + let full = SmtpRequest { + message: Some(message.clone()), + envelope: Some(envelope), + gateway_flags: flags.clone(), + }; + let arg = Encode!(&full).map_err(|e| SubmitMailError::Other(e.to_string()))?; + let out = agent + .update(&canister_id, "smtp_request") + .with_arg(arg) + .call_and_wait() + .await + .map_err(|e| SubmitMailError::Other(agent_err_to_string(e)))?; + let resp = Decode!(&out, SmtpResponse).map_err(|e| SubmitMailError::Other(e.to_string()))?; + match resp { + SmtpResponse::Ok(_) => { + debug!(%canister_id, %to_path, "smtp_request accepted"); + } + SmtpResponse::Err(err) => { + let (code, msg) = smtp_line_from_canister_err(&err); + warn!(%canister_id, %to_path, canister_code = %err.code, "smtp_request rejected"); + return Err(SubmitMailError::Rejected { + code, + message: msg, + failed_recipient: to_path.clone(), + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_simple_message() { + let raw = b"From: a@b\r\nTo: c@d\r\n\r\nhello"; + let m = parse_rfc5322_message(raw).unwrap(); + assert_eq!(m.headers.len(), 2); + assert_eq!(m.body, b"hello"); + } + + #[test] + fn unfold_continuation() { + let block = "Subject: very\r\n long\r\n line"; + let lines = unfold_header_block(block); + assert_eq!(lines.len(), 1); + assert!(lines[0].contains("very long line")); + } +} diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs new file mode 100644 index 0000000..4c484c7 --- /dev/null +++ b/ic-bn-lib/src/smtp/mod.rs @@ -0,0 +1,3 @@ +pub mod address; +pub mod ic; +pub mod session; diff --git a/ic-bn-lib/src/smtp/session/helo.rs b/ic-bn-lib/src/smtp/session/helo.rs new file mode 100644 index 0000000..02b4d08 --- /dev/null +++ b/ic-bn-lib/src/smtp/session/helo.rs @@ -0,0 +1,33 @@ +use std::{io, str::FromStr}; + +use fqdn::FQDN; +use smtp_proto::{ + EXT_8BIT_MIME, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, +}; + +use crate::{network::AsyncReadWrite, smtp::session::Session}; + +impl Session { + /// Handles HELO/EHLO messages + pub async fn handle_helo(&mut self, domain: &str) -> io::Result<()> { + // Validate hostname + let Ok(helo_hostname) = FQDN::from_str(domain) else { + return self.write(b"550 5.5.0 Invalid EHLO hostname.\r\n").await; + }; + + self.data.helo_hostname = Some(helo_hostname); + + let mut response = EhloResponse::new(self.hostname.as_str()); + response.capabilities = EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8; + + // Send STARTTLS cap only if we support TLS & we're not already in TLS mode + if !self.in_starttls && self.tls_config.is_some() { + response.capabilities |= EXT_START_TLS; + } + + let mut buf = Vec::with_capacity(128); + response.write(&mut buf).ok(); + + self.write(&buf).await + } +} diff --git a/ic-bn-lib/src/smtp/session/mod.rs b/ic-bn-lib/src/smtp/session/mod.rs new file mode 100644 index 0000000..d36cad4 --- /dev/null +++ b/ic-bn-lib/src/smtp/session/mod.rs @@ -0,0 +1,98 @@ +pub mod helo; + +use std::{io, net::IpAddr}; + +use fqdn::FQDN; +use mail_auth::MessageAuthenticator; +use rustls::ServerConfig; +use smtp_proto::request::receiver::{ + BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver, LineReceiver, RequestReceiver, +}; +use strum::Display; +use tokio::io::AsyncWriteExt; +use uuid::Uuid; + +use crate::{network::AsyncReadWrite, smtp::address::EmailAddress}; + +/// SMTP session state +#[derive(Default, Display)] +pub enum SessionState { + #[default] + Init, + Request(RequestReceiver), + Data(DataReceiver), + Done, +} + +/// SMTP session data +#[derive(Debug, Default)] +pub struct SessionData { + pub helo_hostname: Option, + pub mail_from: Option, + pub rcpt_to: Option, +} + +/// SMTP session params +#[derive(Debug, Default)] +pub struct SessionParams { + pub max_message_size: u64, + pub verify_ehlo_hostname: bool, + pub verify_spf: bool, + pub verify_reverse_ip: bool, + pub dns_servers: Vec, +} + +pub struct Session { + id: Uuid, + hostname: String, + remote_ip: IpAddr, + stream: S, + state: SessionState, + data: SessionData, + params: SessionParams, + authenticator: MessageAuthenticator, + in_starttls: bool, + tls_config: Option, +} + +impl Session { + pub fn new( + hostname: String, + remote_ip: IpAddr, + stream: S, + params: SessionParams, + authenticator: MessageAuthenticator, + tls_config: Option, + ) -> Self { + Self { + id: Uuid::now_v7(), + hostname, + remote_ip, + stream, + state: SessionState::default(), + data: SessionData::default(), + params, + authenticator, + in_starttls: false, + tls_config, + } + } + + /// Writes given bytes to the session & flushes the buffer + pub async fn write(&mut self, bytes: &[u8]) -> io::Result<()> { + self.stream.write_all(bytes).await?; + self.stream.flush().await?; + Ok(()) + } + + pub async fn process(&mut self) { + match self.state { + SessionState::Init => { + let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.hostname); + self.write(greeting.as_bytes()).await; + } + + _ => {} + } + } +} From c73e8b299e064180fac21195c605c73de42c2644 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Wed, 13 May 2026 21:25:00 +0200 Subject: [PATCH 2/8] Work in progress --- ic-bn-lib/src/lib.rs | 1 + ic-bn-lib/src/smtp/mod.rs | 31 ++++++ ic-bn-lib/src/smtp/session/helo.rs | 38 +++++++- ic-bn-lib/src/smtp/session/mail_from.rs | 58 +++++++++++ ic-bn-lib/src/smtp/session/mod.rs | 124 +++++++++++++++++++++--- ic-bn-lib/src/smtp/session/rcpt_to.rs | 61 ++++++++++++ 6 files changed, 299 insertions(+), 14 deletions(-) create mode 100644 ic-bn-lib/src/smtp/session/mail_from.rs create mode 100644 ic-bn-lib/src/smtp/session/rcpt_to.rs diff --git a/ic-bn-lib/src/lib.rs b/ic-bn-lib/src/lib.rs index 2fd9395..02e7f5b 100644 --- a/ic-bn-lib/src/lib.rs +++ b/ic-bn-lib/src/lib.rs @@ -4,6 +4,7 @@ #![warn(tail_expr_drop_order)] #![allow(clippy::cognitive_complexity)] #![allow(clippy::field_reassign_with_default)] +#![allow(clippy::collapsible_if)] #[cfg(feature = "custom-domains")] pub mod custom_domains; diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs index 4c484c7..caf342b 100644 --- a/ic-bn-lib/src/smtp/mod.rs +++ b/ic-bn-lib/src/smtp/mod.rs @@ -1,3 +1,34 @@ +use async_trait::async_trait; + +use crate::smtp::address::EmailAddress; + pub mod address; pub mod ic; pub mod session; + +/// Recipient resolution policy +pub enum RecipientPolicy { + Accept, + Rewrite(EmailAddress), + Expand(Vec), +} + +/// Recipient resolution error +#[derive(thiserror::Error, Debug)] +pub enum RecipientResolveError { + #[error("Unknown recipient")] + UnknownRecipient, + #[error("Unknown domain")] + UnknownDomain, + #[error("{0}")] + Other(String), +} + +/// Looks up the given recipient & applies `RecipientPolicy` policy +#[async_trait] +pub trait ResolvesRecipient: Send + Sync { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result; +} diff --git a/ic-bn-lib/src/smtp/session/helo.rs b/ic-bn-lib/src/smtp/session/helo.rs index 02b4d08..b9b0849 100644 --- a/ic-bn-lib/src/smtp/session/helo.rs +++ b/ic-bn-lib/src/smtp/session/helo.rs @@ -1,21 +1,57 @@ use std::{io, str::FromStr}; +use ahash::AHashSet; use fqdn::FQDN; +use mail_auth::hickory_resolver::proto::{ProtoErrorKind, rr::RecordType}; use smtp_proto::{ EXT_8BIT_MIME, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, }; +use tracing::info; use crate::{network::AsyncReadWrite, smtp::session::Session}; impl Session { - /// Handles HELO/EHLO messages + /// Handles HELO/EHLO commands pub async fn handle_helo(&mut self, domain: &str) -> io::Result<()> { // Validate hostname let Ok(helo_hostname) = FQDN::from_str(domain) else { return self.write(b"550 5.5.0 Invalid EHLO hostname.\r\n").await; }; + if helo_hostname.depth() < 2 { + return self + .write(b"550 5.5.0 EHLO hostname must be an FQDN.\r\n") + .await; + }; + + // Check if EHLO hostname resolves if configured + if self.params.verify_ehlo_hostname { + match self.authenticator.resolver().lookup_ip(domain).await { + Ok(v) => { + if v.iter().next().is_none() { + return self + .write(b"550 5.5.0 EHLO hostname not found in DNS.\r\n") + .await; + } + } + + Err(e) => { + if matches!(e.kind(), ProtoErrorKind::NoRecordsFound(_)) { + return self + .write(b"550 5.5.0 EHLO hostname not found in DNS.\r\n") + .await; + } + + info!("Unable to lookup '{domain}' in DNS: {e:#}"); + return self + .write(b"451 4.7.25 Temporary error validating EHLO hostname.\r\n") + .await; + } + } + } self.data.helo_hostname = Some(helo_hostname); + self.data.mail_from = None; + self.data.rcpt_to = AHashSet::new(); let mut response = EhloResponse::new(self.hostname.as_str()); response.capabilities = EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8; diff --git a/ic-bn-lib/src/smtp/session/mail_from.rs b/ic-bn-lib/src/smtp/session/mail_from.rs new file mode 100644 index 0000000..0aba1a5 --- /dev/null +++ b/ic-bn-lib/src/smtp/session/mail_from.rs @@ -0,0 +1,58 @@ +use std::{borrow::Cow, io, str::FromStr}; + +use fqdn::FQDN; +use mail_auth::{IprevResult, Parameters}; +use smtp_proto::{ + EXT_8BIT_MIME, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, MailFrom, +}; + +use crate::{ + network::AsyncReadWrite, + smtp::{address::EmailAddress, session::Session}, +}; + +impl Session { + /// Handles MAIL FROM command + pub async fn handle_mail_from(&mut self, from: MailFrom>) -> io::Result<()> { + if self.data.helo_hostname.is_none() { + return self + .write(b"503 5.5.1 Polite people say EHLO first.\r\n") + .await; + } + + if self.data.mail_from.is_some() { + return self + .write(b"503 5.5.1 Multiple MAIL FROM commands not allowed.\r\n") + .await; + } + + // Validate address + let Ok(address) = EmailAddress::from_str(&from.address) else { + return self + .write(b"550 5.7.1 Sender address is incorrect.\r\n") + .await; + }; + + // Validate reverse IP if configured + if self.params.verify_reverse_ip { + let result = self + .authenticator + .verify_iprev(Parameters::from(self.remote_ip)) + .await + .result; + + if !matches!(result, IprevResult::Pass) { + let message = if matches!(result, IprevResult::TempError(_)) { + &b"451 4.7.25 Temporary error validating reverse DNS.\r\n"[..] + } else { + &b"550 5.7.25 Reverse DNS validation failed.\r\n"[..] + }; + + return self.write(message).await; + } + } + + self.data.mail_from = Some(address); + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/session/mod.rs b/ic-bn-lib/src/smtp/session/mod.rs index d36cad4..57370fb 100644 --- a/ic-bn-lib/src/smtp/session/mod.rs +++ b/ic-bn-lib/src/smtp/session/mod.rs @@ -1,7 +1,10 @@ pub mod helo; +pub mod mail_from; +pub mod rcpt_to; -use std::{io, net::IpAddr}; +use std::{io, net::IpAddr, sync::Arc, time::Duration}; +use ahash::AHashSet; use fqdn::FQDN; use mail_auth::MessageAuthenticator; use rustls::ServerConfig; @@ -9,10 +12,19 @@ use smtp_proto::request::receiver::{ BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver, LineReceiver, RequestReceiver, }; use strum::Display; -use tokio::io::AsyncWriteExt; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, + time::{error::Elapsed, timeout}, +}; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; use uuid::Uuid; -use crate::{network::AsyncReadWrite, smtp::address::EmailAddress}; +use crate::{ + network::AsyncReadWrite, + smtp::{ResolvesRecipient, address::EmailAddress}, +}; /// SMTP session state #[derive(Default, Display)] @@ -29,19 +41,38 @@ pub enum SessionState { pub struct SessionData { pub helo_hostname: Option, pub mail_from: Option, - pub rcpt_to: Option, + pub rcpt_to: AHashSet, } -/// SMTP session params -#[derive(Debug, Default)] +/// SMTP session parameters +#[derive(Debug)] pub struct SessionParams { pub max_message_size: u64, + pub max_recipients: usize, + pub max_session_duration: Duration, pub verify_ehlo_hostname: bool, pub verify_spf: bool, pub verify_reverse_ip: bool, - pub dns_servers: Vec, + pub helo_delay: Duration, + pub timeout: Duration, +} + +impl Default for SessionParams { + fn default() -> Self { + Self { + max_message_size: 10 * 1024 * 1024, + max_recipients: 5, + max_session_duration: Duration::from_secs(600), + verify_ehlo_hostname: false, + verify_reverse_ip: false, + verify_spf: false, + helo_delay: Duration::from_secs(3), + timeout: Duration::from_secs(30), + } + } } +/// SMTP Session pub struct Session { id: Uuid, hostname: String, @@ -50,19 +81,24 @@ pub struct Session { state: SessionState, data: SessionData, params: SessionParams, - authenticator: MessageAuthenticator, + authenticator: Arc, + recipient_resolver: Arc, in_starttls: bool, tls_config: Option, + shutdown_token: CancellationToken, } +#[allow(clippy::too_many_arguments)] impl Session { pub fn new( hostname: String, remote_ip: IpAddr, stream: S, params: SessionParams, - authenticator: MessageAuthenticator, + authenticator: Arc, + recipient_resolver: Arc, tls_config: Option, + shutdown_token: CancellationToken, ) -> Self { Self { id: Uuid::now_v7(), @@ -73,8 +109,10 @@ impl Session { data: SessionData::default(), params, authenticator, + recipient_resolver, in_starttls: false, tls_config, + shutdown_token, } } @@ -85,14 +123,74 @@ impl Session { Ok(()) } - pub async fn process(&mut self) { - match self.state { + async fn ingest(&mut self, bytes: &[u8]) -> io::Result<()> { + match &self.state { SessionState::Init => { + // If we have HELO delay configured - try to read from the stream for up to this duration. + // The client needs to wait silently until we send our greeting. + // If something comes in - then the client isn't respecting the protocol and we consider him malicious. + if self.params.helo_delay != Duration::ZERO { + if timeout(self.params.helo_delay, self.stream.read_u8()) + .await + .is_ok() + { + self.write(b"").await?; + } + } + let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.hostname); - self.write(greeting.as_bytes()).await; + self.write(greeting.as_bytes()).await?; + + self.state = SessionState::Request(RequestReceiver::default()) } - _ => {} + SessionState::Request(rx) => {} + SessionState::Data(rx) => {} + SessionState::Done => { + self.stream.shutdown().await.ok(); + } } + + Ok(()) + } + + pub async fn read( + &mut self, + buf: &mut [u8], + res: Result, Elapsed>, + ) -> io::Result<()> { + match res { + Ok(Ok(bytes_read)) => { + self.ingest(&buf[..bytes_read]).await?; + } + Ok(Err(e)) => { + return Err(e); + } + Err(e) => return Err(io::Error::other(e)), + } + + Ok(()) + } + + /// Drives the session forward + pub async fn handle(&mut self) { + let mut buf = vec![0; 8192]; + + loop { + select! { + res = timeout(self.params.timeout, self.stream.read(&mut buf)) => { + if let Err(e) = self.read(&mut buf, res).await { + info!("Session error: {e:#}"); + break; + }; + }, + + () = self.shutdown_token.cancelled() => { + break; + } + } + } + + self.stream.shutdown().await.ok(); } } diff --git a/ic-bn-lib/src/smtp/session/rcpt_to.rs b/ic-bn-lib/src/smtp/session/rcpt_to.rs new file mode 100644 index 0000000..655b396 --- /dev/null +++ b/ic-bn-lib/src/smtp/session/rcpt_to.rs @@ -0,0 +1,61 @@ +use std::{borrow::Cow, io, str::FromStr}; + +use smtp_proto::RcptTo; + +use crate::{ + network::AsyncReadWrite, + smtp::{RecipientPolicy, RecipientResolveError, address::EmailAddress, session::Session}, +}; + +impl Session { + /// Handles RCPT TO command + pub async fn handle_rcpt_to(&mut self, to: RcptTo>) -> io::Result<()> { + if self.data.mail_from.is_none() { + return self + .write(b"503 5.5.1 MAIL FROM is required first.\r\n") + .await; + } + + if self.data.rcpt_to.len() >= self.params.max_recipients { + return self.write(b"455 4.5.3 Too many recipients.\r\n").await; + } + + let Ok(address) = EmailAddress::from_str(&to.address) else { + return self.write(b"550 5.1.2 Incorrect address.\r\n").await; + }; + + if self.data.rcpt_to.contains(&address) { + return self.write(b"250 2.1.5 OK\r\n").await; + } + + match self.recipient_resolver.resolve_recipient(&address).await { + Ok(v) => match v { + RecipientPolicy::Accept => { + self.data.rcpt_to.insert(address); + } + RecipientPolicy::Rewrite(new_address) => { + self.data.rcpt_to.insert(new_address); + } + RecipientPolicy::Expand(new_addresses) => { + self.data.rcpt_to.extend(new_addresses); + } + }, + + Err(e) => match e { + RecipientResolveError::UnknownDomain => { + return self.write(b"550 5.1.2 Relay not allowed.\r\n").await; + } + RecipientResolveError::UnknownRecipient => { + return self.write(b"550 5.1.2 Mailbox does not exist.\r\n").await; + } + RecipientResolveError::Other(_) => { + return self + .write(b"451 4.4.3 Unable to verify address at this time.\r\n") + .await; + } + }, + } + + self.write(b"250 2.1.5 OK\r\n").await + } +} From 5df087a7578967ccc27d27ab27ed0f036bc130d7 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Fri, 15 May 2026 20:51:43 +0200 Subject: [PATCH 3/8] More SMTP work --- ic-bn-lib/Cargo.toml | 5 + ic-bn-lib/src/http/server/mod.rs | 1 - ic-bn-lib/src/network/listener.rs | 8 +- ic-bn-lib/src/network/mod.rs | 11 +- ic-bn-lib/src/smtp/address.rs | 4 + .../smtp/{session/helo.rs => inbound/ehlo.rs} | 34 +- ic-bn-lib/src/smtp/inbound/mail_from.rs | 110 +++++ ic-bn-lib/src/smtp/inbound/manager.rs | 85 ++++ ic-bn-lib/src/smtp/inbound/mod.rs | 167 +++++++ .../src/smtp/{session => inbound}/rcpt_to.rs | 30 +- ic-bn-lib/src/smtp/inbound/session.rs | 426 ++++++++++++++++++ ic-bn-lib/src/smtp/mod.rs | 74 ++- ic-bn-lib/src/smtp/server.rs | 94 ++++ ic-bn-lib/src/smtp/session/mail_from.rs | 58 --- ic-bn-lib/src/smtp/session/mod.rs | 196 -------- ic-bn-lib/tools/smtp_server.rs | 20 + 16 files changed, 1038 insertions(+), 285 deletions(-) rename ic-bn-lib/src/smtp/{session/helo.rs => inbound/ehlo.rs} (60%) create mode 100644 ic-bn-lib/src/smtp/inbound/mail_from.rs create mode 100644 ic-bn-lib/src/smtp/inbound/manager.rs create mode 100644 ic-bn-lib/src/smtp/inbound/mod.rs rename ic-bn-lib/src/smtp/{session => inbound}/rcpt_to.rs (71%) create mode 100644 ic-bn-lib/src/smtp/inbound/session.rs create mode 100644 ic-bn-lib/src/smtp/server.rs delete mode 100644 ic-bn-lib/src/smtp/session/mail_from.rs delete mode 100644 ic-bn-lib/src/smtp/session/mod.rs create mode 100644 ic-bn-lib/tools/smtp_server.rs diff --git a/ic-bn-lib/Cargo.toml b/ic-bn-lib/Cargo.toml index 30eca4d..a47330f 100644 --- a/ic-bn-lib/Cargo.toml +++ b/ic-bn-lib/Cargo.toml @@ -117,6 +117,7 @@ systemstat = "0.2.3" tar = { version = "0.4.44", optional = true } tempfile = "3.23" thiserror = { workspace = true } +tracing-subscriber = "0.3" tokio = { version = "1.47.0", features = ["full"] } tokio-util = { workspace = true } tokio-rustls = { version = "0.26.0", default-features = false, features = [ @@ -157,3 +158,7 @@ required-features = ["vector"] [package.metadata.cargo-all-features] # Limit feature combinations to reduce test duration max_combination_size = 2 + +[[bin]] +name = "smtp-server" +path = "tools/smtp_server.rs" diff --git a/ic-bn-lib/src/http/server/mod.rs b/ic-bn-lib/src/http/server/mod.rs index 7e89d9d..f7be719 100644 --- a/ic-bn-lib/src/http/server/mod.rs +++ b/ic-bn-lib/src/http/server/mod.rs @@ -44,7 +44,6 @@ use tokio::{ time::{sleep, timeout}, }; use tokio_io_timeout::TimeoutStream; -use tokio_rustls::TlsAcceptor; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tower_service::Service; use tracing::{debug, info, warn}; diff --git a/ic-bn-lib/src/network/listener.rs b/ic-bn-lib/src/network/listener.rs index 776e980..b5403d8 100644 --- a/ic-bn-lib/src/network/listener.rs +++ b/ic-bn-lib/src/network/listener.rs @@ -14,7 +14,7 @@ pub enum Listener { impl Listener { /// Create a new Listener - pub fn new(addr: Addr, opts: ListenerOpts) -> Result { + pub fn new(addr: Addr, opts: ListenerOpts) -> io::Result { Ok(match addr { Addr::Tcp(v) => Self::Tcp(listen_tcp(v, opts)?), Addr::Unix(v) => Self::Unix(listen_unix(v, opts)?), @@ -22,7 +22,7 @@ impl Listener { } /// Accept the connection - pub async fn accept(&self) -> Result<(Box, Addr), io::Error> { + pub async fn accept(&self) -> io::Result<(Box, Addr)> { Ok(match self { Self::Tcp(v) => { let x = v.accept().await?; @@ -61,7 +61,7 @@ impl From for Listener { } /// Creates a TCP listener with given opts -pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result { +pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> io::Result { let domain = if addr.is_ipv4() { Domain::IPV4 } else { @@ -85,7 +85,7 @@ pub fn listen_tcp(addr: SocketAddr, opts: ListenerOpts) -> Result Result { +pub fn listen_unix(path: PathBuf, opts: ListenerOpts) -> io::Result { let socket = UnixSocket::new_stream()?; if path.exists() { diff --git a/ic-bn-lib/src/network/mod.rs b/ic-bn-lib/src/network/mod.rs index 7f83397..bd39cc9 100644 --- a/ic-bn-lib/src/network/mod.rs +++ b/ic-bn-lib/src/network/mod.rs @@ -8,19 +8,19 @@ use std::{ use ic_bn_lib_common::types::http::{Stats, TlsInfo}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_rustls::TlsAcceptor; +use tokio_rustls::{TlsAcceptor, server::TlsStream}; pub mod listener; -/// Blanket async read+write trait for streams Box-ing +/// Blanket async read+write trait for streams `Box`-ing pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} impl AsyncReadWrite for T {} /// Performs TLS handshake on the given stream -pub async fn tls_handshake( +pub async fn tls_handshake( rustls_cfg: Arc, - stream: impl AsyncReadWrite, -) -> Result<(impl AsyncReadWrite, TlsInfo), io::Error> { + stream: T, +) -> io::Result<(TlsStream, TlsInfo)> { let tls_acceptor = TlsAcceptor::from(rustls_cfg); // Perform the TLS handshake @@ -28,6 +28,7 @@ pub async fn tls_handshake( let stream = tls_acceptor.accept(stream).await?; let duration = start.elapsed(); + // Obtain TLS info let conn = stream.get_ref().1; let mut tls_info = TlsInfo::try_from(conn).map_err(io::Error::other)?; tls_info.handshake_dur = duration; diff --git a/ic-bn-lib/src/smtp/address.rs b/ic-bn-lib/src/smtp/address.rs index 893b3ef..4db283b 100644 --- a/ic-bn-lib/src/smtp/address.rs +++ b/ic-bn-lib/src/smtp/address.rs @@ -33,6 +33,10 @@ impl FromStr for EmailAddress { fn from_str(s: &str) -> Result { let (local, domain) = s.rsplit_once('@').ok_or(EmailAddressError::AtMissing)?; + if domain.is_empty() { + return Err(EmailAddressError::DomainIncorrect("Empty domain".into())); + } + let domain = FQDN::from_ascii_str(domain) .map_err(|e| EmailAddressError::DomainIncorrect(e.to_string()))?; diff --git a/ic-bn-lib/src/smtp/session/helo.rs b/ic-bn-lib/src/smtp/inbound/ehlo.rs similarity index 60% rename from ic-bn-lib/src/smtp/session/helo.rs rename to ic-bn-lib/src/smtp/inbound/ehlo.rs index b9b0849..a261e36 100644 --- a/ic-bn-lib/src/smtp/session/helo.rs +++ b/ic-bn-lib/src/smtp/inbound/ehlo.rs @@ -1,23 +1,27 @@ -use std::{io, str::FromStr}; +use std::str::FromStr; use ahash::AHashSet; use fqdn::FQDN; -use mail_auth::hickory_resolver::proto::{ProtoErrorKind, rr::RecordType}; +use mail_auth::hickory_resolver::proto::ProtoErrorKind; use smtp_proto::{ - EXT_8BIT_MIME, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, + EXT_8BIT_MIME, EXT_CHUNKING, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, + EhloResponse, }; use tracing::info; -use crate::{network::AsyncReadWrite, smtp::session::Session}; +use crate::{ + network::AsyncReadWrite, + smtp::inbound::{Session, SessionResult}, +}; impl Session { - /// Handles HELO/EHLO commands - pub async fn handle_helo(&mut self, domain: &str) -> io::Result<()> { + /// Handles EHLO/HELO commands + pub async fn handle_ehlo(&mut self, host: &str) -> SessionResult<()> { // Validate hostname - let Ok(helo_hostname) = FQDN::from_str(domain) else { + let Ok(ehlo_hostname) = FQDN::from_str(host) else { return self.write(b"550 5.5.0 Invalid EHLO hostname.\r\n").await; }; - if helo_hostname.depth() < 2 { + if ehlo_hostname.depth() < 2 { return self .write(b"550 5.5.0 EHLO hostname must be an FQDN.\r\n") .await; @@ -25,7 +29,7 @@ impl Session { // Check if EHLO hostname resolves if configured if self.params.verify_ehlo_hostname { - match self.authenticator.resolver().lookup_ip(domain).await { + match self.params.authenticator.resolver().lookup_ip(host).await { Ok(v) => { if v.iter().next().is_none() { return self @@ -41,7 +45,7 @@ impl Session { .await; } - info!("Unable to lookup '{domain}' in DNS: {e:#}"); + info!("{self}: Unable to lookup '{host}' in DNS: {e:#}"); return self .write(b"451 4.7.25 Temporary error validating EHLO hostname.\r\n") .await; @@ -49,15 +53,17 @@ impl Session { } } - self.data.helo_hostname = Some(helo_hostname); + self.data.ehlo_hostname = Some(ehlo_hostname); self.data.mail_from = None; self.data.rcpt_to = AHashSet::new(); - let mut response = EhloResponse::new(self.hostname.as_str()); - response.capabilities = EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8; + let mut response = EhloResponse::new(self.params.hostname.as_str()); + response.capabilities = + EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8 | EXT_CHUNKING; + response.size = self.params.max_message_size; // Send STARTTLS cap only if we support TLS & we're not already in TLS mode - if !self.in_starttls && self.tls_config.is_some() { + if self.tls_info.is_none() && self.params.tls_config.is_some() { response.capabilities |= EXT_START_TLS; } diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs new file mode 100644 index 0000000..a45fa16 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -0,0 +1,110 @@ +use std::{borrow::Cow, fmt::Write, str::FromStr}; + +use mail_auth::{IprevResult, Parameters, SpfResult, spf::verify::SpfParameters}; +use smtp_proto::{MAIL_BY_NOTIFY, MAIL_BY_RETURN, MailFrom}; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + address::EmailAddress, + inbound::{Session, SessionResult}, + }, +}; + +impl Session { + /// Handles MAIL FROM command + pub async fn handle_mail_from(&mut self, from: MailFrom>) -> SessionResult<()> { + let Some(helo_hostname) = &self.data.ehlo_hostname else { + return self + .write(b"503 5.5.1 Polite people say EHLO first.\r\n") + .await; + }; + + if self.data.mail_from.is_some() { + return self + .write(b"503 5.5.1 Multiple MAIL FROM commands are not allowed.\r\n") + .await; + } + + if (from.flags & (MAIL_BY_NOTIFY | MAIL_BY_RETURN)) != 0 { + return self.ext_unsupported("DELIVERBY").await; + } + + if from.mt_priority != 0 { + return self.ext_unsupported("MT-PRIORITY").await; + } + + if from.size > 0 && from.size > self.params.max_message_size { + return self.message_too_big().await; + } + + if from.hold_for != 0 || from.hold_until != 0 { + return self.ext_unsupported("FUTURERELEASE").await; + } + + if from.env_id.is_some() { + return self.ext_unsupported("DSN").await; + } + + // Validate address + let Ok(address) = EmailAddress::from_str(&from.address) else { + return self + .write(b"550 5.7.1 Sender address is incorrect.\r\n") + .await; + }; + + // Validate reverse IP if configured + if self.params.verify_reverse_ip { + let result = self + .params + .authenticator + .verify_iprev(Parameters::from(self.remote_ip)) + .await + .result; + + if !matches!(result, IprevResult::Pass) { + let message = if matches!(result, IprevResult::TempError(_)) { + &b"451 4.7.25 Temporary error validating reverse DNS.\r\n"[..] + } else { + &b"550 5.7.25 Reverse DNS validation failed.\r\n"[..] + }; + + return self.write(message).await; + } + } + + if self.params.verify_spf { + let output = self + .params + .authenticator + .verify_spf(SpfParameters::verify_mail_from( + self.remote_ip, + &helo_hostname.to_string(), + &self.params.hostname, + &from.address, + )) + .await; + + match output.result() { + SpfResult::Pass | SpfResult::Neutral | SpfResult::None => {} + SpfResult::TempError => { + return self + .write(b"451 4.7.24 Temporary SPF validation error.\r\n") + .await; + } + SpfResult::Fail | SpfResult::PermError | SpfResult::SoftFail => { + let mut msg = "550 5.7.23 SPF validation failed".to_string(); + if let Some(v) = output.explanation() { + write!(msg, ": {v}.").ok(); + } + write!(msg, "\r\n").ok(); + return self.write(msg.as_bytes()).await; + } + } + } + + self.write(b"250 2.1.0 OK\r\n").await?; + self.data.mail_from = Some(address); + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/inbound/manager.rs b/ic-bn-lib/src/smtp/inbound/manager.rs new file mode 100644 index 0000000..aa33234 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/manager.rs @@ -0,0 +1,85 @@ +use std::{net::SocketAddr, sync::Arc}; + +use derive_new::new; +use tokio::io::AsyncWriteExt; +use tokio_rustls::server::TlsStream; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info}; + +use crate::{ + network::{AsyncReadWrite, tls_handshake}, + smtp::inbound::{Session, SessionData, SessionParams, SessionResult, SessionUpgrade}, +}; + +/// Manages the lifetime of a single SMTP session. +/// +/// Needed because the SMTP session can transition into TLS state +/// which requires external orchestration. +#[derive(new)] +pub struct SessionManager; + +impl SessionManager { + pub async fn handle_connection( + &self, + stream: S, + remote_addr: SocketAddr, + params: Arc, + shutdown_token: CancellationToken, + ) { + let mut session = Session::new(remote_addr.ip(), stream, params); + + match session.handle(shutdown_token.child_token()).await { + Ok(v) => match v { + SessionUpgrade::No => { + session.stream.shutdown().await.ok(); + } + + SessionUpgrade::StartTls => { + let log_name = session.to_string(); + match session.into_tls().await { + Ok(mut session) => { + if let Err(e) = session.handle(shutdown_token.child_token()).await { + info!("{session}: error: {e:#}"); + session.stream.shutdown().await.ok(); + } + } + Err(e) => { + info!("{log_name}: TLS handshake failed: {e:#}"); + } + }; + } + }, + + Err(e) => { + info!("{session}: error: {e:#}, closing connection"); + if let Err(e) = session.shutdown().await { + debug!("{session}: error closing connection: {e:#}"); + }; + } + } + } +} + +impl Session { + /// Converts the plain-text session into a TLS one by doing a TLS handshake + pub async fn into_tls(self) -> SessionResult>> { + // SAFETY: Code makes sure that we end up here only if tls_config is Some. + // If we ever panic here - it should mean that the core logic is flawed. + let (stream, tls_info) = + tls_handshake(self.params.tls_config.clone().unwrap(), self.stream).await?; + + Ok(Session { + id: self.id, + remote_ip: self.remote_ip, + stream, + state: self.state, + // According to the RFC we need to discard all session data + // after switching into TLS mode. + // https://datatracker.ietf.org/doc/html/rfc3207#section-4.2 + data: SessionData::default(), + counters: self.counters, + params: self.params, + tls_info: Some(tls_info), + }) + } +} diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs new file mode 100644 index 0000000..2bafdd9 --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -0,0 +1,167 @@ +pub mod ehlo; +pub mod mail_from; +pub mod manager; +pub mod rcpt_to; +pub mod session; + +use std::{ + io, + net::IpAddr, + sync::Arc, + time::{Duration, Instant}, +}; + +use ahash::AHashSet; +use fqdn::FQDN; +use ic_bn_lib_common::types::http::TlsInfo; +use mail_auth::MessageAuthenticator; +use rustls::ServerConfig; +use smtp_proto::{ + Error as SmtpError, + request::receiver::{ + BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver, RequestReceiver, + }, +}; +use strum::Display; +use uuid::Uuid; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + DeliversMail, DummyDeliveryAgent, DummyRecipientResolver, ResolvesRecipient, + address::EmailAddress, + }, +}; + +#[derive(thiserror::Error, Debug)] +pub enum SessionError { + #[error("I/O error: {0}")] + Io(#[from] io::Error), + #[error("Timed out")] + Timeout, + #[error("{0}")] + ProtocolError(String), + #[error("{0}")] + SmtpError(#[from] SmtpError), + #[error("Session terminated by client (QUIT)")] + Quit, + #[error("Too many messages per session")] + TooManyMessagesPerSession, + #[error("Session transfer quota ({0}) was exceeded")] + TransferQuotaExceeded(usize), + #[error("Session TTL ({0}s) was exceeded")] + TtlExceeded(u64), + #[error("{0}")] + Other(#[from] anyhow::Error), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionUpgrade { + No, + StartTls, +} + +pub type SessionResult = Result; + +/// SMTP session parameters +pub struct SessionParams { + pub hostname: String, + pub max_message_size: usize, + pub max_recipients: usize, + pub max_session_duration: Duration, + pub max_session_data: usize, + pub max_errors: usize, + pub max_messages_per_session: usize, + pub verify_ehlo_hostname: bool, + pub verify_spf: bool, + pub verify_reverse_ip: bool, + pub helo_delay: Option, + pub timeout: Duration, + pub tls_config: Option>, + pub authenticator: Arc, + pub recipient_resolver: Arc, + pub delivery_agent: Arc, +} + +impl Default for SessionParams { + fn default() -> Self { + Self { + hostname: "".into(), + max_message_size: 10 * 1024 * 1024, + max_recipients: 5, + max_session_duration: Duration::from_secs(600), + max_session_data: 50 * 1024 * 1024, + max_errors: 5, + max_messages_per_session: 5, + verify_ehlo_hostname: false, + verify_reverse_ip: false, + verify_spf: false, + helo_delay: None, + timeout: Duration::from_secs(30), + tls_config: None, + // SAFETY: this never fails + authenticator: Arc::new(MessageAuthenticator::new_cloudflare().unwrap()), + recipient_resolver: Arc::new(DummyRecipientResolver), + delivery_agent: Arc::new(DummyDeliveryAgent), + } + } +} + +/// SMTP session state +#[derive(Display)] +pub enum SessionState { + Greeting, + Request(RequestReceiver), + Data(DataReceiver), + Bdat(BdatReceiver), + RequestTooLarge(DummyLineReceiver), + DataTooLarge(DummyDataReceiver), + None, +} + +impl Default for SessionState { + fn default() -> Self { + Self::Request(RequestReceiver::default()) + } +} + +/// SMTP session data +#[derive(Debug, Default)] +pub struct SessionData { + pub ehlo_hostname: Option, + pub mail_from: Option, + pub rcpt_to: AHashSet, + pub message: Vec, +} + +/// SMTP session counters +#[derive(Debug)] +pub struct SessionCounters { + valid_until: Instant, + bytes_ingested: usize, + messages_queued: usize, + errors: usize, +} + +impl SessionCounters { + fn new(ttl: Duration) -> Self { + Self { + valid_until: Instant::now() + ttl, + bytes_ingested: 0, + messages_queued: 0, + errors: 0, + } + } +} + +/// SMTP Session +pub struct Session { + id: Uuid, + remote_ip: IpAddr, + stream: S, + state: SessionState, + data: SessionData, + counters: SessionCounters, + params: Arc, + tls_info: Option, +} diff --git a/ic-bn-lib/src/smtp/session/rcpt_to.rs b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs similarity index 71% rename from ic-bn-lib/src/smtp/session/rcpt_to.rs rename to ic-bn-lib/src/smtp/inbound/rcpt_to.rs index 655b396..188eff0 100644 --- a/ic-bn-lib/src/smtp/session/rcpt_to.rs +++ b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs @@ -1,15 +1,21 @@ -use std::{borrow::Cow, io, str::FromStr}; +use std::{borrow::Cow, str::FromStr}; -use smtp_proto::RcptTo; +use smtp_proto::{ + RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, RcptTo, +}; use crate::{ network::AsyncReadWrite, - smtp::{RecipientPolicy, RecipientResolveError, address::EmailAddress, session::Session}, + smtp::{ + RecipientPolicy, RecipientResolveError, + address::EmailAddress, + inbound::{Session, SessionResult}, + }, }; impl Session { /// Handles RCPT TO command - pub async fn handle_rcpt_to(&mut self, to: RcptTo>) -> io::Result<()> { + pub async fn handle_rcpt_to(&mut self, to: RcptTo>) -> SessionResult<()> { if self.data.mail_from.is_none() { return self .write(b"503 5.5.1 MAIL FROM is required first.\r\n") @@ -20,6 +26,15 @@ impl Session { return self.write(b"455 4.5.3 Too many recipients.\r\n").await; } + // Check if DSN-related stuff was requested + if (to.flags + & (RCPT_NOTIFY_DELAY | RCPT_NOTIFY_NEVER | RCPT_NOTIFY_SUCCESS | RCPT_NOTIFY_FAILURE)) + != 0 + || to.orcpt.is_some() + { + return self.ext_unsupported("DSN").await; + } + let Ok(address) = EmailAddress::from_str(&to.address) else { return self.write(b"550 5.1.2 Incorrect address.\r\n").await; }; @@ -28,7 +43,12 @@ impl Session { return self.write(b"250 2.1.5 OK\r\n").await; } - match self.recipient_resolver.resolve_recipient(&address).await { + match self + .params + .recipient_resolver + .resolve_recipient(&address) + .await + { Ok(v) => match v { RecipientPolicy::Accept => { self.data.rcpt_to.insert(address); diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs new file mode 100644 index 0000000..ea0771f --- /dev/null +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -0,0 +1,426 @@ +use std::{ + borrow::Cow, + fmt::Display, + net::IpAddr, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Context; +use smtp_proto::{ + Error as SmtpError, Request, + request::receiver::{BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + select, +}; +use tokio_util::{sync::CancellationToken, time::FutureExt}; +use tracing::info; +use uuid::Uuid; + +use crate::{ + network::AsyncReadWrite, + smtp::{ + DeliveryError, Message, + inbound::{ + Session, SessionCounters, SessionData, SessionError, SessionParams, SessionResult, + SessionState, SessionUpgrade, + }, + }, +}; + +impl Display for Session { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SMTP({}): ", self.remote_ip) + } +} + +#[allow(clippy::too_many_arguments)] +impl Session { + pub fn new(remote_ip: IpAddr, stream: S, params: Arc) -> Self { + Self { + id: Uuid::now_v7(), + remote_ip, + stream, + state: SessionState::Greeting, + data: SessionData::default(), + counters: SessionCounters::new(params.max_session_duration), + params, + tls_info: None, + } + } + + /// Writes given bytes to the session & flushes the buffer + pub async fn write(&mut self, bytes: &[u8]) -> SessionResult<()> { + self.stream.write_all(bytes).await?; + self.stream.flush().await?; + Ok(()) + } + + /// Sends greeting message + async fn greeting(&mut self) -> SessionResult<()> { + // If we have HELO delay configured - try to read from the stream for up to this duration. + // The client needs to wait silently until we send our greeting. + // If something comes in - then the client isn't respecting the protocol, + // we consider him malicious and drop the connection. + if let Some(v) = self.params.helo_delay { + match self.stream.read_u8().timeout(v).await { + Ok(Ok(_)) => { + self.write(b"501 5.7.1 Client sent command before greeting banner.\r\n") + .await?; + return Err(SessionError::ProtocolError( + "Command before greeting banner".into(), + )); + } + Ok(Err(e)) => return Err(e.into()), + Err(_) => {} + } + } + + let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.params.hostname); + self.write(greeting.as_bytes()).await?; + + Ok(()) + } + + async fn handle_error(&mut self, error: SmtpError) -> SessionResult<()> { + match error { + SmtpError::UnknownCommand | SmtpError::InvalidResponse { .. } => { + self.write(b"500 5.5.1 Invalid command.\r\n").await?; + } + SmtpError::InvalidSenderAddress => { + self.write(b"501 5.1.8 Bad sender's system address.\r\n") + .await?; + } + SmtpError::InvalidRecipientAddress => { + self.write(b"501 5.1.3 Bad destination mailbox address syntax.\r\n") + .await?; + } + SmtpError::SyntaxError { syntax } => { + self.write(format!("501 5.5.2 Syntax error, expected: {syntax}\r\n").as_bytes()) + .await?; + } + SmtpError::InvalidParameter { param } => { + self.write(format!("501 5.5.4 Invalid parameter {param:?}.\r\n").as_bytes()) + .await?; + } + SmtpError::UnsupportedParameter { param } => { + self.write(format!("504 5.5.4 Unsupported parameter {param:?}.\r\n").as_bytes()) + .await?; + } + SmtpError::ResponseTooLong => { + self.state = SessionState::RequestTooLarge(DummyLineReceiver::default()); + } + SmtpError::NeedsMoreData { .. } => {} + } + + Ok(()) + } + + async fn handle_request(&mut self, req: Request>) -> SessionResult<()> { + match req { + Request::Ehlo { host } | Request::Helo { host } => { + self.handle_ehlo(&host).await?; + } + Request::Mail { from } => { + self.handle_mail_from(from).await?; + } + Request::Rcpt { to } => { + self.handle_rcpt_to(to).await?; + } + Request::Rset => { + self.reset_message(); + self.write(b"250 2.0.0 OK\r\n").await?; + } + Request::Quit => { + self.write(b"221 2.0.0 Bye.\r\n").await?; + return Err(SessionError::Quit); + } + Request::Noop { .. } => { + self.write(b"250 2.0.0 OK\r\n").await?; + } + _ => { + self.write(b"502 5.5.1 Command not implemented.\r\n") + .await?; + self.counters.errors += 1; + } + } + + Ok(()) + } + + async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { + // Check if we alread are over transfer quota + if self.counters.bytes_ingested + bytes.len() >= self.params.max_session_data { + self.write(b"452 4.7.28 Session transfer quota exceeded.\r\n") + .await?; + return Err(SessionError::TransferQuotaExceeded( + self.params.max_session_data, + )); + } + + // Check if we already are over session time quota + if Instant::now() > self.counters.valid_until { + self.write(b"452 4.3.2 Session open for too long.\r\n") + .await?; + return Err(SessionError::TtlExceeded( + self.params.max_session_duration.as_secs(), + )); + } + + self.counters.bytes_ingested += bytes.len(); + let mut iter = bytes.iter(); + // We can't take mutable ref to self.state & self at the same time, so we extract state + // temporarily + let mut state = std::mem::replace(&mut self.state, SessionState::None); + + loop { + match &mut state { + SessionState::Greeting => { + // This is handled separately + unreachable!(); + } + SessionState::Request(rx) => { + match rx.ingest(&mut iter) { + Ok(request) => match request { + // ASCII data + Request::Data => { + if self.can_accept_message().await? { + self.write(b"354 Start mail input; end with .\r\n") + .await?; + self.data.message = Vec::with_capacity(1024); + state = SessionState::Data(DataReceiver::new()); + continue; + } + } + // Binary data + Request::Bdat { + chunk_size, + is_last, + } => { + state = if self.data.message.len() + chunk_size + > self.params.max_message_size + { + SessionState::DataTooLarge(DummyDataReceiver::new_bdat( + chunk_size, + )) + } else { + // Allocate the needed capacity for the chunk + let free = + self.data.message.capacity() - self.data.message.len(); + if free < chunk_size { + self.data.message.reserve(chunk_size - free); + } + + SessionState::Bdat(BdatReceiver::new(chunk_size, is_last)) + } + } + Request::StartTls => { + if self.tls_info.is_some() { + self.write(b"504 5.7.4 Already in TLS mode.\r\n").await?; + self.counters.errors += 1; + } else if self.params.tls_config.is_none() { + self.write(b"502 5.7.0 TLS not available.\r\n").await?; + self.counters.errors += 1; + } else { + self.write(b"220 2.0.0 Ready to start TLS.\r\n").await?; + return Ok(SessionUpgrade::StartTls); + } + } + other_request => { + self.handle_request(other_request).await?; + } + }, + // In case of NeedsMoreData error we just leave + // and wait for new data to be ingested + Err(SmtpError::NeedsMoreData { .. }) => break, + Err(e) => { + self.handle_error(e).await?; + self.counters.errors += 1; + } + } + } + SessionState::Data(rx) => { + // Check if the message already exceeds allowed size + if self.data.message.len() + bytes.len() > self.params.max_message_size { + state = SessionState::DataTooLarge(DummyDataReceiver::new_data(rx)); + continue; + } else if rx.ingest(&mut iter, &mut self.data.message) { + // The message is fully received, time to queue + self.queue_message().await?; + state = SessionState::default(); + } else { + // No end-of-message marker found yet + break; + } + } + SessionState::Bdat(rx) => { + if rx.ingest(&mut iter, &mut self.data.message) { + if self.can_accept_message().await? { + if rx.is_last { + self.queue_message().await?; + } else { + self.write(b"250 2.6.0 Chunk accepted.\r\n").await?; + } + } else { + self.data.message = Vec::with_capacity(0); + } + state = SessionState::default(); + } else { + // Still some bytes left in the chunk + break; + } + } + SessionState::RequestTooLarge(rx) => { + // If line-feed found - issue error, otherwise keep ingesting + if rx.ingest(&mut iter) { + self.write(b"554 5.3.4 Line is too long.\r\n").await?; + state = SessionState::default(); + self.counters.errors += 1; + } else { + break; + } + } + SessionState::DataTooLarge(rx) => { + // If end-of-message marker found - issue error, otherwise keep ingesting + if rx.ingest(&mut iter) { + self.message_too_big().await?; + state = SessionState::default(); + self.counters.errors += 1; + } else { + // No end-of-message marker found yet + break; + } + } + SessionState::None => unreachable!(), + } + } + self.state = state; + + Ok(SessionUpgrade::No) + } + + /// Drives the session forward + pub async fn handle( + &mut self, + shutdown_token: CancellationToken, + ) -> SessionResult { + let mut buf = vec![0; 8192]; + + if matches!(self.state, SessionState::Greeting) { + self.greeting().await?; + self.state = SessionState::default(); + } + + loop { + select! { + // Read from the client with a timeout + res = self.stream.read(&mut buf).timeout(self.params.timeout) => { + match res { + Ok(Ok(bytes_read)) => { + self.ingest(&buf[..bytes_read]).await?; + } + Ok(Err(e)) => { + return Err(e.into()); + } + Err(_) => { + self.write(b"221 2.0.0 Disconnecting due to inactivity.\r\n").await?; + return Err(SessionError::Timeout); + } + } + }, + + () = shutdown_token.cancelled() => { + break; + } + } + } + + Ok(SessionUpgrade::No) + } + + pub(crate) async fn ext_unsupported(&mut self, ext: &str) -> SessionResult<()> { + self.write(b"501 5.5.4 ").await?; + self.write(ext.as_bytes()).await?; + return self.write(b" extension is not supported.\r\n").await; + } + + pub(crate) async fn message_too_big(&mut self) -> SessionResult<()> { + let msg = format!( + "552 5.3.4 Message too big for, we accept up to {} bytes.\r\n", + self.params.max_message_size + ); + return self.write(msg.as_bytes()).await; + } + + async fn queue_message(&mut self) -> SessionResult<()> { + let id = Uuid::now_v7(); + + let message_size = self.data.message.len(); + let message = Message { + id, + ehlo_hostname: self.data.ehlo_hostname.take().unwrap(), + mail_from: self.data.mail_from.take().unwrap(), + rcpt_to: self.data.rcpt_to.drain().collect(), + body: std::mem::take(&mut self.data.message), + }; + + if let Err(e) = self.params.delivery_agent.deliver_mail(message).await { + let msg = match e { + DeliveryError::Permanent(v) => { + format!("550 5.5.0 Permanent delivery error: {v}") + } + DeliveryError::Temporary(v) => { + format!("450 4.5.0 Temporary delivery error: {v}") + } + }; + + self.write(msg.as_bytes()).await?; + self.reset_message(); + return Ok(()); + } + + self.write( + format!("250 2.0.0 Message ({message_size} bytes) queued with id {id}\r\n").as_bytes(), + ) + .await?; + + self.counters.messages_queued += 1; + self.reset_message(); + Ok(()) + } + + async fn can_accept_message(&mut self) -> SessionResult { + if self.counters.messages_queued >= self.params.max_messages_per_session { + self.write(b"452 4.4.5 Maximum number of messages per session exceeded.\r\n") + .await?; + return Err(SessionError::TooManyMessagesPerSession); + } else if self.data.rcpt_to.is_empty() { + self.write(b"503 5.5.1 RCPT TO is required first.\r\n") + .await?; + self.counters.errors += 1; + return Ok(false); + } + + Ok(true) + } + + /// Resets the message-related fields to initial state + fn reset_message(&mut self) { + self.data.mail_from = None; + self.data.rcpt_to.clear(); + self.data.message.clear(); + } + + pub async fn shutdown(&mut self) -> SessionResult<()> { + self.stream + .shutdown() + .timeout(Duration::from_secs(10)) + .await + .context("shutdown timed out")? + .context("shutdown failed")?; + + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs index caf342b..9a3fbfa 100644 --- a/ic-bn-lib/src/smtp/mod.rs +++ b/ic-bn-lib/src/smtp/mod.rs @@ -1,10 +1,17 @@ +use std::fmt::{Debug, Display}; + use async_trait::async_trait; +use fqdn::FQDN; +use itertools::Itertools; +use tracing::warn; +use uuid::Uuid; use crate::smtp::address::EmailAddress; pub mod address; pub mod ic; -pub mod session; +pub mod inbound; +pub mod server; /// Recipient resolution policy pub enum RecipientPolicy { @@ -24,11 +31,74 @@ pub enum RecipientResolveError { Other(String), } +/// Delivery error +#[derive(thiserror::Error, Debug)] +pub enum DeliveryError { + #[error("{0}")] + Temporary(String), + #[error("{0}")] + Permanent(String), +} + +/// Low-level E-Mail representation +#[derive(Debug)] +pub struct Message { + pub id: Uuid, + pub ehlo_hostname: FQDN, + pub mail_from: EmailAddress, + pub rcpt_to: Vec, + pub body: Vec, +} + +impl Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ehlo: {}, from: {}, to: {}, msg: {}", + self.ehlo_hostname, + self.mail_from, + self.rcpt_to.iter().map(|x| x.to_string()).join(", "), + String::from_utf8_lossy(&self.body) + ) + } +} + /// Looks up the given recipient & applies `RecipientPolicy` policy #[async_trait] -pub trait ResolvesRecipient: Send + Sync { +pub trait ResolvesRecipient: Send + Sync + Debug { async fn resolve_recipient( &self, rcpt: &EmailAddress, ) -> Result; } + +/// Delivers the E-Mail message +#[async_trait] +pub trait DeliversMail: Send + Sync + Debug { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError>; +} + +#[derive(Debug)] +pub struct DummyRecipientResolver; + +#[async_trait] +impl ResolvesRecipient for DummyRecipientResolver { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result { + warn!("DummyRecipientResolver: {rcpt}"); + Ok(RecipientPolicy::Accept) + } +} + +#[derive(Debug)] +pub struct DummyDeliveryAgent; + +#[async_trait] +impl DeliversMail for DummyDeliveryAgent { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError> { + warn!("DummyDeliveryAgent: {message}"); + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/server.rs b/ic-bn-lib/src/smtp/server.rs new file mode 100644 index 0000000..d4b1873 --- /dev/null +++ b/ic-bn-lib/src/smtp/server.rs @@ -0,0 +1,94 @@ +use std::{fmt::Display, io, net::SocketAddr, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use ic_bn_lib_common::{traits::Run, types::http::ListenerOpts}; +use tokio::{ + net::{TcpListener, TcpStream}, + select, +}; +use tokio_util::{sync::CancellationToken, task::TaskTracker, time::FutureExt}; +use tracing::{info, warn}; + +use crate::{ + network::listener::listen_tcp, + smtp::inbound::{SessionError, SessionParams, SessionResult, manager::SessionManager}, +}; + +/// Listens for new connections and creates sessions +pub struct Server { + listen_addr: SocketAddr, + listener: TcpListener, + params: Arc, + tracker: TaskTracker, +} + +impl Display for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SMTPServer({})", self.listen_addr) + } +} + +impl Server { + pub fn new(listen_addr: SocketAddr, params: SessionParams) -> Result { + let listener = listen_tcp(listen_addr, ListenerOpts::default())?; + + Ok(Self { + listen_addr, + listener, + params: Arc::new(params), + tracker: TaskTracker::new(), + }) + } + + async fn handle_accept( + &self, + res: io::Result<(TcpStream, SocketAddr)>, + token: &CancellationToken, + ) { + match res { + Ok((stream, addr)) => { + info!("{self}: New connection from {addr}"); + + let (manager, params, token) = + (SessionManager, self.params.clone(), token.child_token()); + + self.tracker.spawn(async move { + manager.handle_connection(stream, addr, params, token).await; + }); + } + + Err(e) => { + warn!("{self}: Unable to accept connection: {e:#}"); + tokio::time::sleep(Duration::from_millis(50)).await; + } + } + } + + pub async fn serve(&self, token: CancellationToken) -> SessionResult<()> { + loop { + select! { + res = self.listener.accept() => { + self.handle_accept(res, &token).await; + } + + () = token.cancelled() => { + self.tracker.close(); + if self.tracker.wait().timeout(Duration::from_secs(60)).await.is_err() { + warn!("{self}: Timed out waiting for connections to close"); + } + break; + } + } + } + + Ok(()) + } +} + +#[async_trait] +impl Run for Server { + async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> { + self.serve(token).await?; + Ok(()) + } +} diff --git a/ic-bn-lib/src/smtp/session/mail_from.rs b/ic-bn-lib/src/smtp/session/mail_from.rs deleted file mode 100644 index 0aba1a5..0000000 --- a/ic-bn-lib/src/smtp/session/mail_from.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::{borrow::Cow, io, str::FromStr}; - -use fqdn::FQDN; -use mail_auth::{IprevResult, Parameters}; -use smtp_proto::{ - EXT_8BIT_MIME, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, MailFrom, -}; - -use crate::{ - network::AsyncReadWrite, - smtp::{address::EmailAddress, session::Session}, -}; - -impl Session { - /// Handles MAIL FROM command - pub async fn handle_mail_from(&mut self, from: MailFrom>) -> io::Result<()> { - if self.data.helo_hostname.is_none() { - return self - .write(b"503 5.5.1 Polite people say EHLO first.\r\n") - .await; - } - - if self.data.mail_from.is_some() { - return self - .write(b"503 5.5.1 Multiple MAIL FROM commands not allowed.\r\n") - .await; - } - - // Validate address - let Ok(address) = EmailAddress::from_str(&from.address) else { - return self - .write(b"550 5.7.1 Sender address is incorrect.\r\n") - .await; - }; - - // Validate reverse IP if configured - if self.params.verify_reverse_ip { - let result = self - .authenticator - .verify_iprev(Parameters::from(self.remote_ip)) - .await - .result; - - if !matches!(result, IprevResult::Pass) { - let message = if matches!(result, IprevResult::TempError(_)) { - &b"451 4.7.25 Temporary error validating reverse DNS.\r\n"[..] - } else { - &b"550 5.7.25 Reverse DNS validation failed.\r\n"[..] - }; - - return self.write(message).await; - } - } - - self.data.mail_from = Some(address); - Ok(()) - } -} diff --git a/ic-bn-lib/src/smtp/session/mod.rs b/ic-bn-lib/src/smtp/session/mod.rs deleted file mode 100644 index 57370fb..0000000 --- a/ic-bn-lib/src/smtp/session/mod.rs +++ /dev/null @@ -1,196 +0,0 @@ -pub mod helo; -pub mod mail_from; -pub mod rcpt_to; - -use std::{io, net::IpAddr, sync::Arc, time::Duration}; - -use ahash::AHashSet; -use fqdn::FQDN; -use mail_auth::MessageAuthenticator; -use rustls::ServerConfig; -use smtp_proto::request::receiver::{ - BdatReceiver, DataReceiver, DummyDataReceiver, DummyLineReceiver, LineReceiver, RequestReceiver, -}; -use strum::Display; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - select, - time::{error::Elapsed, timeout}, -}; -use tokio_util::sync::CancellationToken; -use tracing::{info, warn}; -use uuid::Uuid; - -use crate::{ - network::AsyncReadWrite, - smtp::{ResolvesRecipient, address::EmailAddress}, -}; - -/// SMTP session state -#[derive(Default, Display)] -pub enum SessionState { - #[default] - Init, - Request(RequestReceiver), - Data(DataReceiver), - Done, -} - -/// SMTP session data -#[derive(Debug, Default)] -pub struct SessionData { - pub helo_hostname: Option, - pub mail_from: Option, - pub rcpt_to: AHashSet, -} - -/// SMTP session parameters -#[derive(Debug)] -pub struct SessionParams { - pub max_message_size: u64, - pub max_recipients: usize, - pub max_session_duration: Duration, - pub verify_ehlo_hostname: bool, - pub verify_spf: bool, - pub verify_reverse_ip: bool, - pub helo_delay: Duration, - pub timeout: Duration, -} - -impl Default for SessionParams { - fn default() -> Self { - Self { - max_message_size: 10 * 1024 * 1024, - max_recipients: 5, - max_session_duration: Duration::from_secs(600), - verify_ehlo_hostname: false, - verify_reverse_ip: false, - verify_spf: false, - helo_delay: Duration::from_secs(3), - timeout: Duration::from_secs(30), - } - } -} - -/// SMTP Session -pub struct Session { - id: Uuid, - hostname: String, - remote_ip: IpAddr, - stream: S, - state: SessionState, - data: SessionData, - params: SessionParams, - authenticator: Arc, - recipient_resolver: Arc, - in_starttls: bool, - tls_config: Option, - shutdown_token: CancellationToken, -} - -#[allow(clippy::too_many_arguments)] -impl Session { - pub fn new( - hostname: String, - remote_ip: IpAddr, - stream: S, - params: SessionParams, - authenticator: Arc, - recipient_resolver: Arc, - tls_config: Option, - shutdown_token: CancellationToken, - ) -> Self { - Self { - id: Uuid::now_v7(), - hostname, - remote_ip, - stream, - state: SessionState::default(), - data: SessionData::default(), - params, - authenticator, - recipient_resolver, - in_starttls: false, - tls_config, - shutdown_token, - } - } - - /// Writes given bytes to the session & flushes the buffer - pub async fn write(&mut self, bytes: &[u8]) -> io::Result<()> { - self.stream.write_all(bytes).await?; - self.stream.flush().await?; - Ok(()) - } - - async fn ingest(&mut self, bytes: &[u8]) -> io::Result<()> { - match &self.state { - SessionState::Init => { - // If we have HELO delay configured - try to read from the stream for up to this duration. - // The client needs to wait silently until we send our greeting. - // If something comes in - then the client isn't respecting the protocol and we consider him malicious. - if self.params.helo_delay != Duration::ZERO { - if timeout(self.params.helo_delay, self.stream.read_u8()) - .await - .is_ok() - { - self.write(b"").await?; - } - } - - let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.hostname); - self.write(greeting.as_bytes()).await?; - - self.state = SessionState::Request(RequestReceiver::default()) - } - - SessionState::Request(rx) => {} - SessionState::Data(rx) => {} - SessionState::Done => { - self.stream.shutdown().await.ok(); - } - } - - Ok(()) - } - - pub async fn read( - &mut self, - buf: &mut [u8], - res: Result, Elapsed>, - ) -> io::Result<()> { - match res { - Ok(Ok(bytes_read)) => { - self.ingest(&buf[..bytes_read]).await?; - } - Ok(Err(e)) => { - return Err(e); - } - Err(e) => return Err(io::Error::other(e)), - } - - Ok(()) - } - - /// Drives the session forward - pub async fn handle(&mut self) { - let mut buf = vec![0; 8192]; - - loop { - select! { - res = timeout(self.params.timeout, self.stream.read(&mut buf)) => { - if let Err(e) = self.read(&mut buf, res).await { - info!("Session error: {e:#}"); - break; - }; - }, - - () = self.shutdown_token.cancelled() => { - break; - } - } - } - - self.stream.shutdown().await.ok(); - } -} diff --git a/ic-bn-lib/tools/smtp_server.rs b/ic-bn-lib/tools/smtp_server.rs new file mode 100644 index 0000000..41d99a1 --- /dev/null +++ b/ic-bn-lib/tools/smtp_server.rs @@ -0,0 +1,20 @@ +use std::{net::SocketAddr, str::FromStr, time::Duration}; + +use ic_bn_lib::smtp::{inbound::SessionParams, server::Server}; +use tokio_util::sync::CancellationToken; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().init(); + + let mut params = SessionParams::default(); + params.helo_delay = Some(Duration::from_secs(1)); + params.hostname = "mail.icp.net".into(); + //params.max_message_size = 16; + //params.max_session_duration = Duration::from_secs(30); + //params.max_session_data = 16; + + let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), params).unwrap(); + + server.serve(CancellationToken::new()).await.unwrap(); +} From 54a1a870fe6a091beb44a961b73628052cf5847a Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Sat, 16 May 2026 14:03:08 +0200 Subject: [PATCH 4/8] Add tests --- ic-bn-lib/Cargo.toml | 1 + ic-bn-lib/src/smtp/inbound/ehlo.rs | 10 +- ic-bn-lib/src/smtp/inbound/mail_from.rs | 12 +- ic-bn-lib/src/smtp/inbound/manager.rs | 20 ++- ic-bn-lib/src/smtp/inbound/mod.rs | 224 ++++++++++++++++++++++-- ic-bn-lib/src/smtp/inbound/rcpt_to.rs | 10 +- ic-bn-lib/src/smtp/inbound/session.rs | 95 +++++----- ic-bn-lib/src/smtp/mod.rs | 5 +- ic-bn-lib/src/smtp/server.rs | 18 +- ic-bn-lib/tools/smtp_server.rs | 10 +- 10 files changed, 307 insertions(+), 98 deletions(-) diff --git a/ic-bn-lib/Cargo.toml b/ic-bn-lib/Cargo.toml index a47330f..0fcaa72 100644 --- a/ic-bn-lib/Cargo.toml +++ b/ic-bn-lib/Cargo.toml @@ -149,6 +149,7 @@ mock-io = { version = "0.3.2", features = ["full"] } # Do not upgrade unless you want to upgrade `rand` too rand_regex = "=0.17.0" tempfile = "3.20.0" +tokio-test = "0.4" [[bench]] name = "vector" diff --git a/ic-bn-lib/src/smtp/inbound/ehlo.rs b/ic-bn-lib/src/smtp/inbound/ehlo.rs index a261e36..580a61a 100644 --- a/ic-bn-lib/src/smtp/inbound/ehlo.rs +++ b/ic-bn-lib/src/smtp/inbound/ehlo.rs @@ -28,8 +28,8 @@ impl Session { }; // Check if EHLO hostname resolves if configured - if self.params.verify_ehlo_hostname { - match self.params.authenticator.resolver().lookup_ip(host).await { + if self.cfg.verify_ehlo_hostname { + match self.cfg.authenticator.resolver().lookup_ip(host).await { Ok(v) => { if v.iter().next().is_none() { return self @@ -57,13 +57,13 @@ impl Session { self.data.mail_from = None; self.data.rcpt_to = AHashSet::new(); - let mut response = EhloResponse::new(self.params.hostname.as_str()); + let mut response = EhloResponse::new(self.cfg.hostname.as_str()); response.capabilities = EXT_ENHANCED_STATUS_CODES | EXT_8BIT_MIME | EXT_SMTP_UTF8 | EXT_CHUNKING; - response.size = self.params.max_message_size; + response.size = self.cfg.max_message_size; // Send STARTTLS cap only if we support TLS & we're not already in TLS mode - if self.tls_info.is_none() && self.params.tls_config.is_some() { + if self.tls_info.is_none() && self.cfg.tls_enabled() { response.capabilities |= EXT_START_TLS; } diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs index a45fa16..46bc8e2 100644 --- a/ic-bn-lib/src/smtp/inbound/mail_from.rs +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -34,7 +34,7 @@ impl Session { return self.ext_unsupported("MT-PRIORITY").await; } - if from.size > 0 && from.size > self.params.max_message_size { + if from.size > self.cfg.max_message_size { return self.message_too_big().await; } @@ -54,9 +54,9 @@ impl Session { }; // Validate reverse IP if configured - if self.params.verify_reverse_ip { + if self.cfg.verify_reverse_ip { let result = self - .params + .cfg .authenticator .verify_iprev(Parameters::from(self.remote_ip)) .await @@ -73,14 +73,14 @@ impl Session { } } - if self.params.verify_spf { + if self.cfg.verify_spf { let output = self - .params + .cfg .authenticator .verify_spf(SpfParameters::verify_mail_from( self.remote_ip, &helo_hostname.to_string(), - &self.params.hostname, + &self.cfg.hostname, &from.address, )) .await; diff --git a/ic-bn-lib/src/smtp/inbound/manager.rs b/ic-bn-lib/src/smtp/inbound/manager.rs index aa33234..5a1c0e2 100644 --- a/ic-bn-lib/src/smtp/inbound/manager.rs +++ b/ic-bn-lib/src/smtp/inbound/manager.rs @@ -8,7 +8,9 @@ use tracing::{debug, info}; use crate::{ network::{AsyncReadWrite, tls_handshake}, - smtp::inbound::{Session, SessionData, SessionParams, SessionResult, SessionUpgrade}, + smtp::inbound::{ + Session, SessionConfig, SessionData, SessionResult, SessionTlsMode, SessionUpgrade, + }, }; /// Manages the lifetime of a single SMTP session. @@ -23,7 +25,7 @@ impl SessionManager { &self, stream: S, remote_addr: SocketAddr, - params: Arc, + params: Arc, shutdown_token: CancellationToken, ) { let mut session = Session::new(remote_addr.ip(), stream, params); @@ -63,10 +65,14 @@ impl SessionManager { impl Session { /// Converts the plain-text session into a TLS one by doing a TLS handshake pub async fn into_tls(self) -> SessionResult>> { - // SAFETY: Code makes sure that we end up here only if tls_config is Some. - // If we ever panic here - it should mean that the core logic is flawed. - let (stream, tls_info) = - tls_handshake(self.params.tls_config.clone().unwrap(), self.stream).await?; + // SAFETY: We should end up here only if TLS is enabled. + // It's better to panic otherwise. + let tls_config = match &self.cfg.tls_mode { + SessionTlsMode::Allowed(v) | SessionTlsMode::Required(v) => v.clone(), + SessionTlsMode::Disabled => unreachable!(), + }; + + let (stream, tls_info) = tls_handshake(tls_config, self.stream).await?; Ok(Session { id: self.id, @@ -78,7 +84,7 @@ impl Session { // https://datatracker.ietf.org/doc/html/rfc3207#section-4.2 data: SessionData::default(), counters: self.counters, - params: self.params, + cfg: self.cfg, tls_info: Some(tls_info), }) } diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs index 2bafdd9..2bc0ecf 100644 --- a/ic-bn-lib/src/smtp/inbound/mod.rs +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -5,6 +5,7 @@ pub mod rcpt_to; pub mod session; use std::{ + fmt::Display, io, net::IpAddr, sync::Arc, @@ -40,21 +41,24 @@ pub enum SessionError { #[error("Timed out")] Timeout, #[error("{0}")] - ProtocolError(String), - #[error("{0}")] SmtpError(#[from] SmtpError), #[error("Session terminated by client (QUIT)")] Quit, + #[error("Client is sending before greeting")] + SendsBeforeGreeting, #[error("Too many messages per session")] TooManyMessagesPerSession, - #[error("Session transfer quota ({0}) was exceeded")] + #[error("Session transfer quota ({0} bytes) was exceeded")] TransferQuotaExceeded(usize), #[error("Session TTL ({0}s) was exceeded")] TtlExceeded(u64), + #[error("Too many errors")] + TooManyErrors, #[error("{0}")] Other(#[from] anyhow::Error), } +/// Indicates if a session needs to be upgraded to TLS #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionUpgrade { No, @@ -63,8 +67,15 @@ pub enum SessionUpgrade { pub type SessionResult = Result; -/// SMTP session parameters -pub struct SessionParams { +/// Session TLS mode +pub enum SessionTlsMode { + Disabled, + Allowed(Arc), + Required(Arc), +} + +/// SMTP session config +pub struct SessionConfig { pub hostname: String, pub max_message_size: usize, pub max_recipients: usize, @@ -73,17 +84,18 @@ pub struct SessionParams { pub max_errors: usize, pub max_messages_per_session: usize, pub verify_ehlo_hostname: bool, - pub verify_spf: bool, + pub verify_sender_domain: bool, pub verify_reverse_ip: bool, + pub verify_spf: bool, pub helo_delay: Option, pub timeout: Duration, - pub tls_config: Option>, + pub tls_mode: SessionTlsMode, pub authenticator: Arc, pub recipient_resolver: Arc, pub delivery_agent: Arc, } -impl Default for SessionParams { +impl Default for SessionConfig { fn default() -> Self { Self { hostname: "".into(), @@ -95,10 +107,11 @@ impl Default for SessionParams { max_messages_per_session: 5, verify_ehlo_hostname: false, verify_reverse_ip: false, + verify_sender_domain: false, verify_spf: false, helo_delay: None, timeout: Duration::from_secs(30), - tls_config: None, + tls_mode: SessionTlsMode::Disabled, // SAFETY: this never fails authenticator: Arc::new(MessageAuthenticator::new_cloudflare().unwrap()), recipient_resolver: Arc::new(DummyRecipientResolver), @@ -107,15 +120,35 @@ impl Default for SessionParams { } } +impl SessionConfig { + pub const fn tls_enabled(&self) -> bool { + matches!( + self.tls_mode, + SessionTlsMode::Allowed(_) | SessionTlsMode::Required(_) + ) + } + + pub const fn tls_required(&self) -> bool { + matches!(self.tls_mode, SessionTlsMode::Required(_)) + } +} + /// SMTP session state #[derive(Display)] pub enum SessionState { + /// Need to send greeting Greeting, + /// Default - command/response Request(RequestReceiver), + /// ASCII data reception Data(DataReceiver), + /// Binary data reception Bdat(BdatReceiver), + /// Too long request received - blackhole RequestTooLarge(DummyLineReceiver), + /// Too large data received - blackhole DataTooLarge(DummyDataReceiver), + /// Dummy None, } @@ -125,7 +158,7 @@ impl Default for SessionState { } } -/// SMTP session data +/// SMTP dynamic session data #[derive(Debug, Default)] pub struct SessionData { pub ehlo_hostname: Option, @@ -162,6 +195,175 @@ pub struct Session { state: SessionState, data: SessionData, counters: SessionCounters, - params: Arc, + cfg: Arc, tls_info: Option, } + +impl Display for Session { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SMTP/Session({})", self.remote_ip) + } +} + +impl Session { + pub fn new(remote_ip: IpAddr, stream: S, cfg: Arc) -> Self { + Self { + id: Uuid::now_v7(), + remote_ip, + stream, + state: SessionState::Greeting, + data: SessionData::default(), + counters: SessionCounters::new(cfg.max_session_duration), + cfg, + tls_info: None, + } + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use tokio_util::sync::CancellationToken; + + use super::*; + + fn create_session(stream: S, helo_delay: Option) -> Session { + let mut cfg = SessionConfig::default(); + cfg.hostname = "test".into(); + cfg.max_errors = 3; + cfg.max_message_size = 512; + cfg.helo_delay = helo_delay; + cfg.max_messages_per_session = 3; + cfg.max_session_data = 1024; + cfg.max_recipients = 3; + + Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)) + } + + fn create_basic_stream() -> tokio_test::io::Builder { + let mut builder = tokio_test::io::Builder::new(); + + builder.write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"EHLO foo.bar\r\n") + .write(b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + builder + } + + fn stream_send_message(b: &mut tokio_test::io::Builder) { + b.read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"354 Start mail input; end with .\r\n") + .read(b"foobarmessage\r\n.\r\n") + .write(b"250 2.0.0 Message (13 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n"); + } + + #[tokio::test] + async fn test_basic_session() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + let stream = builder + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_client_sends_before_greeting() { + let stream = tokio_test::io::Builder::new() + .read(b"EHLO foo.bar\r\n") + .write(b"501 5.7.1 Client sent command before greeting banner.\r\n") + .build(); + + let mut session = create_session(stream, Some(Duration::from_millis(100))); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::SendsBeforeGreeting + )); + } + + #[tokio::test] + async fn test_max_recipients() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"455 4.5.3 Too many recipients.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_message_size() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"354 Start mail input; end with .\r\n") + .read(format!("{}\r\n.\r\n", "1".repeat(513)).as_bytes()) + .write(b"552 5.3.4 Message too big for, we accept up to 512 bytes.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_messages_per_session() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + stream_send_message(&mut builder); + stream_send_message(&mut builder); + + let stream = builder + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"DATA\r\n") + .write(b"452 4.4.5 Maximum number of messages per session exceeded.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TooManyMessagesPerSession + )); + } +} diff --git a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs index 188eff0..ff199d2 100644 --- a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs +++ b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs @@ -22,10 +22,6 @@ impl Session { .await; } - if self.data.rcpt_to.len() >= self.params.max_recipients { - return self.write(b"455 4.5.3 Too many recipients.\r\n").await; - } - // Check if DSN-related stuff was requested if (to.flags & (RCPT_NOTIFY_DELAY | RCPT_NOTIFY_NEVER | RCPT_NOTIFY_SUCCESS | RCPT_NOTIFY_FAILURE)) @@ -43,8 +39,12 @@ impl Session { return self.write(b"250 2.1.5 OK\r\n").await; } + if self.data.rcpt_to.len() >= self.cfg.max_recipients { + return self.write(b"455 4.5.3 Too many recipients.\r\n").await; + } + match self - .params + .cfg .recipient_resolver .resolve_recipient(&address) .await diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs index ea0771f..6d44c8e 100644 --- a/ic-bn-lib/src/smtp/inbound/session.rs +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -1,8 +1,5 @@ use std::{ borrow::Cow, - fmt::Display, - net::IpAddr, - sync::Arc, time::{Duration, Instant}, }; @@ -16,41 +13,18 @@ use tokio::{ select, }; use tokio_util::{sync::CancellationToken, time::FutureExt}; -use tracing::info; use uuid::Uuid; use crate::{ network::AsyncReadWrite, smtp::{ DeliveryError, Message, - inbound::{ - Session, SessionCounters, SessionData, SessionError, SessionParams, SessionResult, - SessionState, SessionUpgrade, - }, + inbound::{Session, SessionError, SessionResult, SessionState, SessionUpgrade}, }, }; -impl Display for Session { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SMTP({}): ", self.remote_ip) - } -} - #[allow(clippy::too_many_arguments)] impl Session { - pub fn new(remote_ip: IpAddr, stream: S, params: Arc) -> Self { - Self { - id: Uuid::now_v7(), - remote_ip, - stream, - state: SessionState::Greeting, - data: SessionData::default(), - counters: SessionCounters::new(params.max_session_duration), - params, - tls_info: None, - } - } - /// Writes given bytes to the session & flushes the buffer pub async fn write(&mut self, bytes: &[u8]) -> SessionResult<()> { self.stream.write_all(bytes).await?; @@ -64,21 +38,22 @@ impl Session { // The client needs to wait silently until we send our greeting. // If something comes in - then the client isn't respecting the protocol, // we consider him malicious and drop the connection. - if let Some(v) = self.params.helo_delay { - match self.stream.read_u8().timeout(v).await { - Ok(Ok(_)) => { - self.write(b"501 5.7.1 Client sent command before greeting banner.\r\n") - .await?; - return Err(SessionError::ProtocolError( - "Command before greeting banner".into(), - )); + if let Some(v) = self.cfg.helo_delay { + let mut buf = vec![0; 128]; + match self.stream.read(&mut buf).timeout(v).await { + Ok(Ok(bytes_read)) => { + if bytes_read > 0 { + self.write(b"501 5.7.1 Client sent command before greeting banner.\r\n") + .await?; + return Err(SessionError::SendsBeforeGreeting); + } } Ok(Err(e)) => return Err(e.into()), Err(_) => {} } } - let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.params.hostname); + let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.cfg.hostname); self.write(greeting.as_bytes()).await?; Ok(()) @@ -151,28 +126,34 @@ impl Session { } async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { - // Check if we alread are over transfer quota - if self.counters.bytes_ingested + bytes.len() >= self.params.max_session_data { + // Check if we are over session transfer quota + if self.counters.bytes_ingested + bytes.len() >= self.cfg.max_session_data { self.write(b"452 4.7.28 Session transfer quota exceeded.\r\n") .await?; return Err(SessionError::TransferQuotaExceeded( - self.params.max_session_data, + self.cfg.max_session_data, )); } - // Check if we already are over session time quota + // Check if we are over session time quota if Instant::now() > self.counters.valid_until { self.write(b"452 4.3.2 Session open for too long.\r\n") .await?; return Err(SessionError::TtlExceeded( - self.params.max_session_duration.as_secs(), + self.cfg.max_session_duration.as_secs(), )); } + // Check if we are over error limit + if self.counters.errors > self.cfg.max_errors { + self.write(b"452 4.3.2 Too many errors.\r\n").await?; + return Err(SessionError::TooManyErrors); + } + self.counters.bytes_ingested += bytes.len(); let mut iter = bytes.iter(); - // We can't take mutable ref to self.state & self at the same time, so we extract state - // temporarily + // We can't take mutable ref to self.state & self at the same time, + // so we extract state temporarily. let mut state = std::mem::replace(&mut self.state, SessionState::None); loop { @@ -199,8 +180,9 @@ impl Session { chunk_size, is_last, } => { + // Check if we will be past max message limit with this chunk state = if self.data.message.len() + chunk_size - > self.params.max_message_size + > self.cfg.max_message_size { SessionState::DataTooLarge(DummyDataReceiver::new_bdat( chunk_size, @@ -220,7 +202,7 @@ impl Session { if self.tls_info.is_some() { self.write(b"504 5.7.4 Already in TLS mode.\r\n").await?; self.counters.errors += 1; - } else if self.params.tls_config.is_none() { + } else if !self.cfg.tls_enabled() { self.write(b"502 5.7.0 TLS not available.\r\n").await?; self.counters.errors += 1; } else { @@ -243,7 +225,7 @@ impl Session { } SessionState::Data(rx) => { // Check if the message already exceeds allowed size - if self.data.message.len() + bytes.len() > self.params.max_message_size { + if self.data.message.len() + bytes.len() > self.cfg.max_message_size { state = SessionState::DataTooLarge(DummyDataReceiver::new_data(rx)); continue; } else if rx.ingest(&mut iter, &mut self.data.message) { @@ -279,6 +261,7 @@ impl Session { state = SessionState::default(); self.counters.errors += 1; } else { + // No line-feed found yet break; } } @@ -316,7 +299,7 @@ impl Session { loop { select! { // Read from the client with a timeout - res = self.stream.read(&mut buf).timeout(self.params.timeout) => { + res = self.stream.read(&mut buf).timeout(self.cfg.timeout) => { match res { Ok(Ok(bytes_read)) => { self.ingest(&buf[..bytes_read]).await?; @@ -349,24 +332,30 @@ impl Session { pub(crate) async fn message_too_big(&mut self) -> SessionResult<()> { let msg = format!( "552 5.3.4 Message too big for, we accept up to {} bytes.\r\n", - self.params.max_message_size + self.cfg.max_message_size ); return self.write(msg.as_bytes()).await; } async fn queue_message(&mut self) -> SessionResult<()> { + #[cfg(not(test))] let id = Uuid::now_v7(); + #[cfg(test)] + let id = Uuid::nil(); let message_size = self.data.message.len(); + + // SAFETY: Code makes sure these are all Some(). + // It's better to panic in tests if they are not. let message = Message { id, - ehlo_hostname: self.data.ehlo_hostname.take().unwrap(), + ehlo_hostname: self.data.ehlo_hostname.clone().unwrap(), mail_from: self.data.mail_from.take().unwrap(), rcpt_to: self.data.rcpt_to.drain().collect(), body: std::mem::take(&mut self.data.message), }; - if let Err(e) = self.params.delivery_agent.deliver_mail(message).await { + if let Err(e) = self.cfg.delivery_agent.deliver_mail(message).await { let msg = match e { DeliveryError::Permanent(v) => { format!("550 5.5.0 Permanent delivery error: {v}") @@ -392,7 +381,7 @@ impl Session { } async fn can_accept_message(&mut self) -> SessionResult { - if self.counters.messages_queued >= self.params.max_messages_per_session { + if self.counters.messages_queued >= self.cfg.max_messages_per_session { self.write(b"452 4.4.5 Maximum number of messages per session exceeded.\r\n") .await?; return Err(SessionError::TooManyMessagesPerSession); @@ -406,7 +395,7 @@ impl Session { Ok(true) } - /// Resets the message-related fields to initial state + /// Resets the message-related fields to their initial state fn reset_message(&mut self) { self.data.mail_from = None; self.data.rcpt_to.clear(); @@ -414,6 +403,8 @@ impl Session { } pub async fn shutdown(&mut self) -> SessionResult<()> { + self.write(b"421 4.3.0 Server shutting down.\r\n").await?; + self.stream .shutdown() .timeout(Duration::from_secs(10)) diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs index 9a3fbfa..1ffa1d2 100644 --- a/ic-bn-lib/src/smtp/mod.rs +++ b/ic-bn-lib/src/smtp/mod.rs @@ -54,11 +54,14 @@ impl Display for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "ehlo: {}, from: {}, to: {}, msg: {}", + "id: {}, ehlo: {}, from: {}, to: {}, msg: {}", + self.id, self.ehlo_hostname, self.mail_from, self.rcpt_to.iter().map(|x| x.to_string()).join(", "), String::from_utf8_lossy(&self.body) + .replace('\n', "\\n") + .replace('\r', "\\r") ) } } diff --git a/ic-bn-lib/src/smtp/server.rs b/ic-bn-lib/src/smtp/server.rs index d4b1873..4fa2bfc 100644 --- a/ic-bn-lib/src/smtp/server.rs +++ b/ic-bn-lib/src/smtp/server.rs @@ -11,29 +11,35 @@ use tracing::{info, warn}; use crate::{ network::listener::listen_tcp, - smtp::inbound::{SessionError, SessionParams, SessionResult, manager::SessionManager}, + smtp::inbound::{SessionConfig, SessionError, SessionResult, manager::SessionManager}, }; -/// Listens for new connections and creates sessions +/// Listens for new SMTP connections and creates sessions pub struct Server { listen_addr: SocketAddr, listener: TcpListener, - params: Arc, + params: Arc, tracker: TaskTracker, } impl Display for Server { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SMTPServer({})", self.listen_addr) + write!(f, "SMTP/Server({})", self.listen_addr) } } impl Server { - pub fn new(listen_addr: SocketAddr, params: SessionParams) -> Result { + pub fn new(listen_addr: SocketAddr, cfg: SessionConfig) -> Result { let listener = listen_tcp(listen_addr, ListenerOpts::default())?; + Self::new_with_listener(listener, cfg) + } + pub fn new_with_listener( + listener: TcpListener, + params: SessionConfig, + ) -> Result { Ok(Self { - listen_addr, + listen_addr: listener.local_addr()?, listener, params: Arc::new(params), tracker: TaskTracker::new(), diff --git a/ic-bn-lib/tools/smtp_server.rs b/ic-bn-lib/tools/smtp_server.rs index 41d99a1..38f7f12 100644 --- a/ic-bn-lib/tools/smtp_server.rs +++ b/ic-bn-lib/tools/smtp_server.rs @@ -1,20 +1,20 @@ use std::{net::SocketAddr, str::FromStr, time::Duration}; -use ic_bn_lib::smtp::{inbound::SessionParams, server::Server}; +use ic_bn_lib::smtp::{inbound::SessionConfig, server::Server}; use tokio_util::sync::CancellationToken; #[tokio::main] async fn main() { tracing_subscriber::fmt().init(); - let mut params = SessionParams::default(); - params.helo_delay = Some(Duration::from_secs(1)); - params.hostname = "mail.icp.net".into(); + let mut cfg = SessionConfig::default(); + cfg.helo_delay = Some(Duration::from_secs(1)); + cfg.hostname = "mail.icp.net".into(); //params.max_message_size = 16; //params.max_session_duration = Duration::from_secs(30); //params.max_session_data = 16; - let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), params).unwrap(); + let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), cfg).unwrap(); server.serve(CancellationToken::new()).await.unwrap(); } From fb9e07143ebf9f788778b0a40fb9f93df6528108 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Sun, 17 May 2026 11:55:27 +0200 Subject: [PATCH 5/8] SMTP reworks, add tests --- ic-bn-lib/src/smtp/inbound/ehlo.rs | 36 +++-- ic-bn-lib/src/smtp/inbound/mail_from.rs | 27 ++-- ic-bn-lib/src/smtp/inbound/mod.rs | 183 ++++++++++++++++++++++-- ic-bn-lib/src/smtp/inbound/rcpt_to.rs | 18 +-- ic-bn-lib/src/smtp/inbound/session.rs | 181 ++++++++++++++--------- ic-bn-lib/src/smtp/mod.rs | 4 +- ic-bn-lib/tools/smtp_server.rs | 6 +- 7 files changed, 342 insertions(+), 113 deletions(-) diff --git a/ic-bn-lib/src/smtp/inbound/ehlo.rs b/ic-bn-lib/src/smtp/inbound/ehlo.rs index 580a61a..88612f1 100644 --- a/ic-bn-lib/src/smtp/inbound/ehlo.rs +++ b/ic-bn-lib/src/smtp/inbound/ehlo.rs @@ -1,13 +1,11 @@ use std::str::FromStr; -use ahash::AHashSet; use fqdn::FQDN; use mail_auth::hickory_resolver::proto::ProtoErrorKind; use smtp_proto::{ EXT_8BIT_MIME, EXT_CHUNKING, EXT_ENHANCED_STATUS_CODES, EXT_SMTP_UTF8, EXT_START_TLS, EhloResponse, }; -use tracing::info; use crate::{ network::AsyncReadWrite, @@ -16,14 +14,22 @@ use crate::{ impl Session { /// Handles EHLO/HELO commands - pub async fn handle_ehlo(&mut self, host: &str) -> SessionResult<()> { + pub async fn handle_ehlo(&mut self, host: &str, extended: bool) -> SessionResult<()> { // Validate hostname let Ok(ehlo_hostname) = FQDN::from_str(host) else { - return self.write(b"550 5.5.0 Invalid EHLO hostname.\r\n").await; + return self.reply("550", "5.5.0", "Invalid EHLO hostname.").await; }; + + // If EHLO hostname is already set to the same value - just reply directly + if let Some(v) = &self.data.ehlo_hostname + && v == &ehlo_hostname + { + return self.send_ehlo(extended).await; + } + if ehlo_hostname.depth() < 2 { return self - .write(b"550 5.5.0 EHLO hostname must be an FQDN.\r\n") + .reply("550", "5.5.0", "EHLO hostname must be an FQDN.") .await; }; @@ -33,7 +39,7 @@ impl Session { Ok(v) => { if v.iter().next().is_none() { return self - .write(b"550 5.5.0 EHLO hostname not found in DNS.\r\n") + .reply("550", "5.5.0", "EHLO hostname not found in DNS.") .await; } } @@ -41,21 +47,29 @@ impl Session { Err(e) => { if matches!(e.kind(), ProtoErrorKind::NoRecordsFound(_)) { return self - .write(b"550 5.5.0 EHLO hostname not found in DNS.\r\n") + .reply("550", "5.5.0", "EHLO hostname not found in DNS.") .await; } - info!("{self}: Unable to lookup '{host}' in DNS: {e:#}"); return self - .write(b"451 4.7.25 Temporary error validating EHLO hostname.\r\n") + .reply("451", "4.7.25", "Temporary error validating EHLO hostname.") .await; } } } + self.reset_message(); self.data.ehlo_hostname = Some(ehlo_hostname); - self.data.mail_from = None; - self.data.rcpt_to = AHashSet::new(); + + return self.send_ehlo(extended).await; + } + + async fn send_ehlo(&mut self, extended: bool) -> SessionResult<()> { + if !extended { + return self + .write(format!("250 {} you had me at HELO\r\n", self.cfg.hostname).as_bytes()) + .await; + } let mut response = EhloResponse::new(self.cfg.hostname.as_str()); response.capabilities = diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs index 46bc8e2..45bba6c 100644 --- a/ic-bn-lib/src/smtp/inbound/mail_from.rs +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -16,13 +16,17 @@ impl Session { pub async fn handle_mail_from(&mut self, from: MailFrom>) -> SessionResult<()> { let Some(helo_hostname) = &self.data.ehlo_hostname else { return self - .write(b"503 5.5.1 Polite people say EHLO first.\r\n") + .reply("503", "5.5.1", "Polite people say EHLO first.") .await; }; if self.data.mail_from.is_some() { return self - .write(b"503 5.5.1 Multiple MAIL FROM commands are not allowed.\r\n") + .reply( + "503", + "5.5.1", + "Multiple MAIL FROM commands are not allowed.", + ) .await; } @@ -49,7 +53,7 @@ impl Session { // Validate address let Ok(address) = EmailAddress::from_str(&from.address) else { return self - .write(b"550 5.7.1 Sender address is incorrect.\r\n") + .reply("550", "5.7.1", "Sender address is incorrect.") .await; }; @@ -63,13 +67,13 @@ impl Session { .result; if !matches!(result, IprevResult::Pass) { - let message = if matches!(result, IprevResult::TempError(_)) { - &b"451 4.7.25 Temporary error validating reverse DNS.\r\n"[..] + let (code, ext, msg) = if matches!(result, IprevResult::TempError(_)) { + ("451", "4.7.25", "Temporary error validating reverse DNS.") } else { - &b"550 5.7.25 Reverse DNS validation failed.\r\n"[..] + ("550", "5.7.25", "Reverse DNS validation failed.") }; - return self.write(message).await; + return self.reply(code, ext, msg).await; } } @@ -89,21 +93,22 @@ impl Session { SpfResult::Pass | SpfResult::Neutral | SpfResult::None => {} SpfResult::TempError => { return self - .write(b"451 4.7.24 Temporary SPF validation error.\r\n") + .reply("451", "4.7.24", "Temporary SPF validation error.") .await; } SpfResult::Fail | SpfResult::PermError | SpfResult::SoftFail => { - let mut msg = "550 5.7.23 SPF validation failed".to_string(); + let mut msg = "SPF validation failed".to_string(); if let Some(v) = output.explanation() { write!(msg, ": {v}.").ok(); } write!(msg, "\r\n").ok(); - return self.write(msg.as_bytes()).await; + + return self.reply("550", "5.7.23", &msg).await; } } } - self.write(b"250 2.1.0 OK\r\n").await?; + self.reply("250", "2.1.0", "OK").await?; self.data.mail_from = Some(address); Ok(()) } diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs index 2bc0ecf..4ed4e96 100644 --- a/ic-bn-lib/src/smtp/inbound/mod.rs +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -13,6 +13,7 @@ use std::{ }; use ahash::AHashSet; +use bytes::Bytes; use fqdn::FQDN; use ic_bn_lib_common::types::http::TlsInfo; use mail_auth::MessageAuthenticator; @@ -76,29 +77,37 @@ pub enum SessionTlsMode { /// SMTP session config pub struct SessionConfig { - pub hostname: String, + hostname: String, + greeting: Bytes, + pub max_message_size: usize, pub max_recipients: usize, pub max_session_duration: Duration, pub max_session_data: usize, pub max_errors: usize, pub max_messages_per_session: usize, + pub verify_ehlo_hostname: bool, pub verify_sender_domain: bool, pub verify_reverse_ip: bool, pub verify_spf: bool, pub helo_delay: Option, + pub timeout: Duration, pub tls_mode: SessionTlsMode, + pub authenticator: Arc, pub recipient_resolver: Arc, pub delivery_agent: Arc, } -impl Default for SessionConfig { - fn default() -> Self { +impl SessionConfig { + pub fn new(hostname: &str) -> Self { + let greeting = format!("220 {hostname} ESMTP IC SMTP Gateway\r\n"); + Self { - hostname: "".into(), + hostname: hostname.into(), + greeting: Bytes::from(greeting), max_message_size: 10 * 1024 * 1024, max_recipients: 5, max_session_duration: Duration::from_secs(600), @@ -118,9 +127,7 @@ impl Default for SessionConfig { delivery_agent: Arc::new(DummyDeliveryAgent), } } -} -impl SessionConfig { pub const fn tls_enabled(&self) -> bool { matches!( self.tls_mode, @@ -224,14 +231,35 @@ impl Session { mod tests { use std::str::FromStr; + use async_trait::async_trait; + use fqdn::fqdn; use tokio_util::sync::CancellationToken; + use crate::smtp::{DeliveryError, Message}; + use super::*; + #[derive(Debug)] + pub struct TestDeliveryAgent(Option, Option); + + #[async_trait] + impl DeliversMail for TestDeliveryAgent { + async fn deliver_mail(&self, message: Message) -> Result<(), DeliveryError> { + if let Some(e) = &self.1 { + return Err(e.clone()); + } + + if let Some(v) = &self.0 { + assert_eq!(v, &message); + } + + Ok(()) + } + } + fn create_session(stream: S, helo_delay: Option) -> Session { - let mut cfg = SessionConfig::default(); - cfg.hostname = "test".into(); - cfg.max_errors = 3; + let mut cfg = SessionConfig::new("test"); + cfg.max_errors = 5; cfg.max_message_size = 512; cfg.helo_delay = helo_delay; cfg.max_messages_per_session = 3; @@ -245,8 +273,11 @@ mod tests { let mut builder = tokio_test::io::Builder::new(); builder.write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"HELO foo.bar\r\n") + .write(b"250 test you had me at HELO\r\n") .read(b"EHLO foo.bar\r\n") .write(b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + builder } @@ -261,9 +292,45 @@ mod tests { .write(b"250 2.0.0 Message (13 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n"); } + #[tokio::test] + async fn test_ehlo_required() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"503 5.5.1 Polite people say EHLO first.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + #[tokio::test] async fn test_basic_session() { let mut builder = create_basic_stream(); + builder + .read(b"RCPT TO:\r\n") + .write(b"503 5.5.1 MAIL FROM is required first.\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"MAIL FROM:\r\n") + .write(b"503 5.5.1 Multiple MAIL FROM commands are not allowed.\r\n") + .read(b"DATA\r\n") + .write(b"503 5.5.1 RCPT TO is required first.\r\n") + .read(b"RSET\r\n") + .write(b"250 2.0.0 OK\r\n") + .read(b"NOOP\r\n") + .write(b"250 2.0.0 OK\r\n") + .read(b"FOOB\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"HELP\r\n") + .write(b"502 5.5.1 Command not implemented.\r\n"); + stream_send_message(&mut builder); let stream = builder .read(b"QUIT\r\n") @@ -293,6 +360,78 @@ mod tests { )); } + #[tokio::test] + async fn test_data() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + + let stream = builder + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: EmailAddress::from_str("foo@bar").unwrap(), + rcpt_to: vec![EmailAddress::from_str("dead@beef").unwrap()], + body: b"foobarmessage".to_vec(), + }), + None, + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_bdat() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"BDAT 10\r\n") + .read(b"0123456789") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10\r\n") + .read(b"9876543210") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10 LAST\r\n") + .read(b"0123456789") + .write(b"250 2.0.0 Message (30 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: EmailAddress::from_str("foo@bar").unwrap(), + rcpt_to: vec![EmailAddress::from_str("baz@baz").unwrap()], + body: b"012345678998765432100123456789".to_vec(), + }), + None, + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + #[tokio::test] async fn test_max_recipients() { let stream = create_basic_stream() @@ -366,4 +505,30 @@ mod tests { SessionError::TooManyMessagesPerSession )); } + + #[tokio::test] + async fn test_max_errors() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") + .write(b"452 4.3.2 Too many errors.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TooManyErrors + )); + } } diff --git a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs index ff199d2..642dfa5 100644 --- a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs +++ b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs @@ -18,7 +18,7 @@ impl Session { pub async fn handle_rcpt_to(&mut self, to: RcptTo>) -> SessionResult<()> { if self.data.mail_from.is_none() { return self - .write(b"503 5.5.1 MAIL FROM is required first.\r\n") + .reply("503", "5.5.1", "MAIL FROM is required first.") .await; } @@ -32,15 +32,15 @@ impl Session { } let Ok(address) = EmailAddress::from_str(&to.address) else { - return self.write(b"550 5.1.2 Incorrect address.\r\n").await; + return self.reply("550", "5.1.2", "Incorrect address.").await; }; if self.data.rcpt_to.contains(&address) { - return self.write(b"250 2.1.5 OK\r\n").await; + return self.reply("250", "2.1.5", "OK").await; } if self.data.rcpt_to.len() >= self.cfg.max_recipients { - return self.write(b"455 4.5.3 Too many recipients.\r\n").await; + return self.reply("455", "4.5.3", "Too many recipients.").await; } match self @@ -63,19 +63,21 @@ impl Session { Err(e) => match e { RecipientResolveError::UnknownDomain => { - return self.write(b"550 5.1.2 Relay not allowed.\r\n").await; + return self + .reply("550", "5.1.2", "Unknown recipient domain.") + .await; } RecipientResolveError::UnknownRecipient => { - return self.write(b"550 5.1.2 Mailbox does not exist.\r\n").await; + return self.reply("550", "5.1.2", "Mailbox does not exist.").await; } RecipientResolveError::Other(_) => { return self - .write(b"451 4.4.3 Unable to verify address at this time.\r\n") + .reply("451", "4.4.3", "Unable to verify address at this time.") .await; } }, } - self.write(b"250 2.1.5 OK\r\n").await + self.reply("250", "2.1.5", "OK").await } } diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs index 6d44c8e..05656a6 100644 --- a/ic-bn-lib/src/smtp/inbound/session.rs +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -23,6 +23,8 @@ use crate::{ }, }; +const MAX_REPLY_LEN: usize = 256; + #[allow(clippy::too_many_arguments)] impl Session { /// Writes given bytes to the session & flushes the buffer @@ -32,6 +34,53 @@ impl Session { Ok(()) } + /// Replies with given codes & message. + /// + /// It accepts replies up to `MAX_REPLY_LEN` long since it uses + /// a stack array to avoid heap allocation for performance reasons. + /// If ever this module would need more - increase the constant. + pub(crate) async fn reply(&mut self, code: &str, ext: &str, msg: &str) -> SessionResult<()> { + let len = code.len() + ext.len() + msg.len() + 4; + assert!( + len <= MAX_REPLY_LEN, + "Reply longer than supported - increase MAX_REPLY_LEN" + ); + + // Poor man's `format!` + let mut buf = [0; MAX_REPLY_LEN]; + let (mut i, mut j) = (0, code.len()); + buf[i..j].copy_from_slice(code.as_bytes()); + buf[j] = b' '; + i += code.len() + 1; + j += ext.len() + 1; + buf[i..j].copy_from_slice(ext.as_bytes()); + buf[j] = b' '; + i += ext.len() + 1; + j += msg.len() + 1; + buf[i..j].copy_from_slice(msg.as_bytes()); + buf[j] = b'\r'; + buf[j + 1] = b'\n'; + + self.write(&buf[..len]).await + } + + pub(crate) async fn ext_unsupported(&mut self, ext: &str) -> SessionResult<()> { + self.reply( + "501", + "5.5.4", + &format!("{ext} extension is not supported."), + ) + .await + } + + pub(crate) async fn message_too_big(&mut self) -> SessionResult<()> { + let msg = format!( + "Message too big for, we accept up to {} bytes.", + self.cfg.max_message_size + ); + return self.reply("552", "5.3.4", &msg).await; + } + /// Sends greeting message async fn greeting(&mut self) -> SessionResult<()> { // If we have HELO delay configured - try to read from the stream for up to this duration. @@ -43,8 +92,12 @@ impl Session { match self.stream.read(&mut buf).timeout(v).await { Ok(Ok(bytes_read)) => { if bytes_read > 0 { - self.write(b"501 5.7.1 Client sent command before greeting banner.\r\n") - .await?; + self.reply( + "501", + "5.7.1", + "Client sent command before greeting banner.", + ) + .await?; return Err(SessionError::SendsBeforeGreeting); } } @@ -53,50 +106,46 @@ impl Session { } } - let greeting = format!("220 {} ESMTP IC SMTP Gateway\r\n", self.cfg.hostname); - self.write(greeting.as_bytes()).await?; - - Ok(()) + self.write(&self.cfg.greeting.clone()).await } async fn handle_error(&mut self, error: SmtpError) -> SessionResult<()> { - match error { + let (code, ext, msg) = match error { SmtpError::UnknownCommand | SmtpError::InvalidResponse { .. } => { - self.write(b"500 5.5.1 Invalid command.\r\n").await?; + ("500", "5.5.1", "Invalid command.".to_string()) } SmtpError::InvalidSenderAddress => { - self.write(b"501 5.1.8 Bad sender's system address.\r\n") - .await?; - } - SmtpError::InvalidRecipientAddress => { - self.write(b"501 5.1.3 Bad destination mailbox address syntax.\r\n") - .await?; + ("501", "5.1.8", "Bad sender's system address.".to_string()) } + SmtpError::InvalidRecipientAddress => ( + "501", + "5.1.3", + "Bad destination mailbox address syntax.".to_string(), + ), SmtpError::SyntaxError { syntax } => { - self.write(format!("501 5.5.2 Syntax error, expected: {syntax}\r\n").as_bytes()) - .await?; + ("501", "5.5.2", format!("Syntax error, expected: {syntax}")) } SmtpError::InvalidParameter { param } => { - self.write(format!("501 5.5.4 Invalid parameter {param:?}.\r\n").as_bytes()) - .await?; + ("501", "5.5.4", format!("Invalid parameter {param:?}.")) } SmtpError::UnsupportedParameter { param } => { - self.write(format!("504 5.5.4 Unsupported parameter {param:?}.\r\n").as_bytes()) - .await?; + ("504", "5.5.4", format!("Unsupported parameter {param:?}.")) } - SmtpError::ResponseTooLong => { - self.state = SessionState::RequestTooLarge(DummyLineReceiver::default()); - } - SmtpError::NeedsMoreData { .. } => {} - } + // These are handled one level above + SmtpError::ResponseTooLong | SmtpError::NeedsMoreData { .. } => unreachable!(), + }; - Ok(()) + self.counters.errors += 1; + self.reply(code, ext, &msg).await } async fn handle_request(&mut self, req: Request>) -> SessionResult<()> { match req { - Request::Ehlo { host } | Request::Helo { host } => { - self.handle_ehlo(&host).await?; + Request::Ehlo { host } => { + self.handle_ehlo(&host, true).await?; + } + Request::Helo { host } => { + self.handle_ehlo(&host, false).await?; } Request::Mail { from } => { self.handle_mail_from(from).await?; @@ -106,17 +155,17 @@ impl Session { } Request::Rset => { self.reset_message(); - self.write(b"250 2.0.0 OK\r\n").await?; + self.reply("250", "2.0.0", "OK").await?; } Request::Quit => { - self.write(b"221 2.0.0 Bye.\r\n").await?; + self.reply("221", "2.0.0", "Bye.").await?; return Err(SessionError::Quit); } Request::Noop { .. } => { - self.write(b"250 2.0.0 OK\r\n").await?; + self.reply("250", "2.0.0", "OK").await?; } _ => { - self.write(b"502 5.5.1 Command not implemented.\r\n") + self.reply("502", "5.5.1", "Command not implemented.") .await?; self.counters.errors += 1; } @@ -128,7 +177,7 @@ impl Session { async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { // Check if we are over session transfer quota if self.counters.bytes_ingested + bytes.len() >= self.cfg.max_session_data { - self.write(b"452 4.7.28 Session transfer quota exceeded.\r\n") + self.reply("452", "4.7.28", "Session transfer quota exceeded.") .await?; return Err(SessionError::TransferQuotaExceeded( self.cfg.max_session_data, @@ -137,7 +186,7 @@ impl Session { // Check if we are over session time quota if Instant::now() > self.counters.valid_until { - self.write(b"452 4.3.2 Session open for too long.\r\n") + self.reply("452", "4.3.2", "Session open for too long.") .await?; return Err(SessionError::TtlExceeded( self.cfg.max_session_duration.as_secs(), @@ -145,8 +194,8 @@ impl Session { } // Check if we are over error limit - if self.counters.errors > self.cfg.max_errors { - self.write(b"452 4.3.2 Too many errors.\r\n").await?; + if self.counters.errors >= self.cfg.max_errors { + self.reply("452", "4.3.2", "Too many errors.").await?; return Err(SessionError::TooManyErrors); } @@ -200,13 +249,13 @@ impl Session { } Request::StartTls => { if self.tls_info.is_some() { - self.write(b"504 5.7.4 Already in TLS mode.\r\n").await?; + self.reply("504", "5.7.4", "Already in TLS mode.").await?; self.counters.errors += 1; } else if !self.cfg.tls_enabled() { - self.write(b"502 5.7.0 TLS not available.\r\n").await?; + self.reply("502", "5.7.0", "TLS not available.").await?; self.counters.errors += 1; } else { - self.write(b"220 2.0.0 Ready to start TLS.\r\n").await?; + self.reply("220", "2.0.0", "Ready to start TLS.").await?; return Ok(SessionUpgrade::StartTls); } } @@ -217,9 +266,12 @@ impl Session { // In case of NeedsMoreData error we just leave // and wait for new data to be ingested Err(SmtpError::NeedsMoreData { .. }) => break, + Err(SmtpError::ResponseTooLong) => { + state = SessionState::RequestTooLarge(DummyLineReceiver::default()); + continue; + } Err(e) => { self.handle_error(e).await?; - self.counters.errors += 1; } } } @@ -243,7 +295,7 @@ impl Session { if rx.is_last { self.queue_message().await?; } else { - self.write(b"250 2.6.0 Chunk accepted.\r\n").await?; + self.reply("250", "2.6.0", "Chunk accepted.").await?; } } else { self.data.message = Vec::with_capacity(0); @@ -257,7 +309,7 @@ impl Session { SessionState::RequestTooLarge(rx) => { // If line-feed found - issue error, otherwise keep ingesting if rx.ingest(&mut iter) { - self.write(b"554 5.3.4 Line is too long.\r\n").await?; + self.reply("554", "5.3.4", "Line is too long.").await?; state = SessionState::default(); self.counters.errors += 1; } else { @@ -308,7 +360,7 @@ impl Session { return Err(e.into()); } Err(_) => { - self.write(b"221 2.0.0 Disconnecting due to inactivity.\r\n").await?; + self.reply("221", "2.0.0", "Disconnecting due to inactivity.").await?; return Err(SessionError::Timeout); } } @@ -323,20 +375,6 @@ impl Session { Ok(SessionUpgrade::No) } - pub(crate) async fn ext_unsupported(&mut self, ext: &str) -> SessionResult<()> { - self.write(b"501 5.5.4 ").await?; - self.write(ext.as_bytes()).await?; - return self.write(b" extension is not supported.\r\n").await; - } - - pub(crate) async fn message_too_big(&mut self) -> SessionResult<()> { - let msg = format!( - "552 5.3.4 Message too big for, we accept up to {} bytes.\r\n", - self.cfg.max_message_size - ); - return self.write(msg.as_bytes()).await; - } - async fn queue_message(&mut self) -> SessionResult<()> { #[cfg(not(test))] let id = Uuid::now_v7(); @@ -356,22 +394,24 @@ impl Session { }; if let Err(e) = self.cfg.delivery_agent.deliver_mail(message).await { - let msg = match e { + let (code, ext, msg) = match e { DeliveryError::Permanent(v) => { - format!("550 5.5.0 Permanent delivery error: {v}") + ("550", "5.5.0", format!("Permanent delivery error: {v}")) } DeliveryError::Temporary(v) => { - format!("450 4.5.0 Temporary delivery error: {v}") + ("450", "4.5.0", format!("Temporary delivery error: {v}")) } }; - self.write(msg.as_bytes()).await?; + self.reply(code, ext, &msg).await?; self.reset_message(); return Ok(()); } - self.write( - format!("250 2.0.0 Message ({message_size} bytes) queued with id {id}\r\n").as_bytes(), + self.reply( + "250", + "2.0.0", + &format!("Message ({message_size} bytes) queued with id {id}"), ) .await?; @@ -382,11 +422,15 @@ impl Session { async fn can_accept_message(&mut self) -> SessionResult { if self.counters.messages_queued >= self.cfg.max_messages_per_session { - self.write(b"452 4.4.5 Maximum number of messages per session exceeded.\r\n") - .await?; + self.reply( + "452", + "4.4.5", + "Maximum number of messages per session exceeded.", + ) + .await?; return Err(SessionError::TooManyMessagesPerSession); } else if self.data.rcpt_to.is_empty() { - self.write(b"503 5.5.1 RCPT TO is required first.\r\n") + self.reply("503", "5.5.1", "RCPT TO is required first.") .await?; self.counters.errors += 1; return Ok(false); @@ -396,15 +440,14 @@ impl Session { } /// Resets the message-related fields to their initial state - fn reset_message(&mut self) { + pub(crate) fn reset_message(&mut self) { self.data.mail_from = None; self.data.rcpt_to.clear(); self.data.message.clear(); } + /// Closes the connection pub async fn shutdown(&mut self) -> SessionResult<()> { - self.write(b"421 4.3.0 Server shutting down.\r\n").await?; - self.stream .shutdown() .timeout(Duration::from_secs(10)) diff --git a/ic-bn-lib/src/smtp/mod.rs b/ic-bn-lib/src/smtp/mod.rs index 1ffa1d2..24756a8 100644 --- a/ic-bn-lib/src/smtp/mod.rs +++ b/ic-bn-lib/src/smtp/mod.rs @@ -32,7 +32,7 @@ pub enum RecipientResolveError { } /// Delivery error -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Clone, Debug)] pub enum DeliveryError { #[error("{0}")] Temporary(String), @@ -41,7 +41,7 @@ pub enum DeliveryError { } /// Low-level E-Mail representation -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq, Hash)] pub struct Message { pub id: Uuid, pub ehlo_hostname: FQDN, diff --git a/ic-bn-lib/tools/smtp_server.rs b/ic-bn-lib/tools/smtp_server.rs index 38f7f12..feddcd0 100644 --- a/ic-bn-lib/tools/smtp_server.rs +++ b/ic-bn-lib/tools/smtp_server.rs @@ -7,12 +7,12 @@ use tokio_util::sync::CancellationToken; async fn main() { tracing_subscriber::fmt().init(); - let mut cfg = SessionConfig::default(); - cfg.helo_delay = Some(Duration::from_secs(1)); - cfg.hostname = "mail.icp.net".into(); + let mut cfg = SessionConfig::new("mail.icp.net"); + //cfg.helo_delay = Some(Duration::from_secs(1)); //params.max_message_size = 16; //params.max_session_duration = Duration::from_secs(30); //params.max_session_data = 16; + cfg.max_errors = 3; let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), cfg).unwrap(); From ceac4cb1b2dc0ca42aa255d53b70923b096e0c0a Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Sun, 17 May 2026 17:03:31 +0200 Subject: [PATCH 6/8] Add STARTTLS test, more work --- ic-bn-lib/src/smtp/address.rs | 12 +- ic-bn-lib/src/smtp/inbound/mail_from.rs | 10 ++ ic-bn-lib/src/smtp/inbound/mod.rs | 215 +++++++++++++++++++++--- ic-bn-lib/src/smtp/inbound/rcpt_to.rs | 9 +- ic-bn-lib/src/smtp/inbound/session.rs | 12 +- ic-bn-lib/tools/smtp_server.rs | 20 ++- 6 files changed, 245 insertions(+), 33 deletions(-) diff --git a/ic-bn-lib/src/smtp/address.rs b/ic-bn-lib/src/smtp/address.rs index 4db283b..5d76b6d 100644 --- a/ic-bn-lib/src/smtp/address.rs +++ b/ic-bn-lib/src/smtp/address.rs @@ -3,7 +3,7 @@ use std::{fmt::Display, str::FromStr}; use derive_new::new; use fqdn::FQDN; -#[derive(thiserror::Error, Debug, PartialEq, Eq)] +#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq)] pub enum EmailAddressError { #[error("@ is missing")] AtMissing, @@ -16,7 +16,7 @@ pub enum EmailAddressError { /// Currently we don't validate the local part at all /// and just consider everything to the right from the /// rightmost @ as a domain part. -#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, new)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, new)] pub struct EmailAddress { pub local: String, pub domain: FQDN, @@ -47,6 +47,14 @@ impl FromStr for EmailAddress { } } +impl TryFrom<&str> for EmailAddress { + type Error = EmailAddressError; + + fn try_from(value: &str) -> Result { + Self::from_str(value) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs index 45bba6c..261d105 100644 --- a/ic-bn-lib/src/smtp/inbound/mail_from.rs +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -30,6 +30,16 @@ impl Session { .await; } + if self.cfg.tls_required() && self.tls_info.is_none() { + return self + .reply( + "503", + "5.5.1", + "TLS is required to submit mail on this server.", + ) + .await; + } + if (from.flags & (MAIL_BY_NOTIFY | MAIL_BY_RETURN)) != 0 { return self.ext_unsupported("DELIVERBY").await; } diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs index 4ed4e96..de6b4f9 100644 --- a/ic-bn-lib/src/smtp/inbound/mod.rs +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -12,7 +12,6 @@ use std::{ time::{Duration, Instant}, }; -use ahash::AHashSet; use bytes::Bytes; use fqdn::FQDN; use ic_bn_lib_common::types::http::TlsInfo; @@ -170,7 +169,7 @@ impl Default for SessionState { pub struct SessionData { pub ehlo_hostname: Option, pub mail_from: Option, - pub rcpt_to: AHashSet, + pub rcpt_to: Vec, pub message: Vec, } @@ -229,13 +228,23 @@ impl Session { #[cfg(test)] mod tests { - use std::str::FromStr; + use std::{net::SocketAddr, str::FromStr}; use async_trait::async_trait; use fqdn::fqdn; + use rustls::{ClientConfig, pki_types::ServerName}; + use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + use tokio_rustls::TlsConnector; use tokio_util::sync::CancellationToken; - use crate::smtp::{DeliveryError, Message}; + use crate::{ + smtp::{ + DeliveryError, Message, RecipientPolicy, RecipientResolveError, + inbound::manager::SessionManager, + }, + tests::{TEST_CERT_1, TEST_KEY_1}, + tls::{resolver::StubResolver, verify::NoopServerCertVerifier}, + }; use super::*; @@ -257,6 +266,32 @@ mod tests { } } + #[derive(Debug)] + pub struct TestRecipientResolver( + EmailAddress, + Option, + Option>, + ); + + #[async_trait] + impl ResolvesRecipient for TestRecipientResolver { + async fn resolve_recipient( + &self, + rcpt: &EmailAddress, + ) -> Result { + assert_eq!(rcpt, &self.0); + if let Some(v) = &self.1 { + return Ok(RecipientPolicy::Rewrite(v.clone())); + } + + if let Some(v) = &self.2 { + return Ok(RecipientPolicy::Expand(v.clone())); + } + + Ok(RecipientPolicy::Accept) + } + } + fn create_session(stream: S, helo_delay: Option) -> Session { let mut cfg = SessionConfig::new("test"); cfg.max_errors = 5; @@ -360,6 +395,55 @@ mod tests { )); } + #[tokio::test] + async fn test_bdat() { + let stream = create_basic_stream() + .read(b"MAIL FROM:\r\n") + .write(b"250 2.1.0 OK\r\n") + .read(b"RCPT TO:\r\n") + .write(b"250 2.1.5 OK\r\n") + .read(b"BDAT 10\r\n") + .read(b"01234") + .read(b"56789") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10\r\n") + .read(b"9876543210") + .write(b"250 2.6.0 Chunk accepted.\r\n") + .read(b"BDAT 10 LAST\r\n") + .read(b"0123456789") + .write(b"250 2.0.0 Message (30 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let agent = TestDeliveryAgent( + Some(Message { + id: Uuid::nil(), + ehlo_hostname: fqdn!("foo.bar"), + mail_from: "foo@bar".try_into().unwrap(), + rcpt_to: vec!["bar@baz".try_into().unwrap()], + body: b"012345678998765432100123456789".to_vec(), + }), + None, + ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + Some("bar@baz".try_into().unwrap()), + None, + ); + + let mut cfg = SessionConfig::new("test"); + cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + #[tokio::test] async fn test_data() { let mut builder = create_basic_stream(); @@ -375,14 +459,21 @@ mod tests { id: Uuid::nil(), ehlo_hostname: fqdn!("foo.bar"), mail_from: EmailAddress::from_str("foo@bar").unwrap(), - rcpt_to: vec![EmailAddress::from_str("dead@beef").unwrap()], + rcpt_to: vec![EmailAddress::from_str("bar@baz").unwrap()], body: b"foobarmessage".to_vec(), }), None, ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + Some("bar@baz".try_into().unwrap()), + None, + ); let mut cfg = SessionConfig::new("test"); cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); assert!(matches!( @@ -392,21 +483,11 @@ mod tests { } #[tokio::test] - async fn test_bdat() { - let stream = create_basic_stream() - .read(b"MAIL FROM:\r\n") - .write(b"250 2.1.0 OK\r\n") - .read(b"RCPT TO:\r\n") - .write(b"250 2.1.5 OK\r\n") - .read(b"BDAT 10\r\n") - .read(b"0123456789") - .write(b"250 2.6.0 Chunk accepted.\r\n") - .read(b"BDAT 10\r\n") - .read(b"9876543210") - .write(b"250 2.6.0 Chunk accepted.\r\n") - .read(b"BDAT 10 LAST\r\n") - .read(b"0123456789") - .write(b"250 2.0.0 Message (30 bytes) queued with id 00000000-0000-0000-0000-000000000000\r\n") + async fn test_expand() { + let mut builder = create_basic_stream(); + stream_send_message(&mut builder); + + let stream = builder .read(b"QUIT\r\n") .write(b"221 2.0.0 Bye.\r\n") .build(); @@ -416,14 +497,28 @@ mod tests { id: Uuid::nil(), ehlo_hostname: fqdn!("foo.bar"), mail_from: EmailAddress::from_str("foo@bar").unwrap(), - rcpt_to: vec![EmailAddress::from_str("baz@baz").unwrap()], - body: b"012345678998765432100123456789".to_vec(), + rcpt_to: vec![ + EmailAddress::from_str("dead@beef").unwrap(), + EmailAddress::from_str("dead@dead").unwrap(), + EmailAddress::from_str("bar@bax").unwrap(), + ], + body: b"foobarmessage".to_vec(), }), None, ); + let resolver = TestRecipientResolver( + "dead@beef".try_into().unwrap(), + None, + Some(vec![ + "dead@dead".try_into().unwrap(), + "bar@bax".try_into().unwrap(), + ]), + ); let mut cfg = SessionConfig::new("test"); cfg.delivery_agent = Arc::new(agent); + cfg.recipient_resolver = Arc::new(resolver); + let mut session = Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)); assert!(matches!( @@ -531,4 +626,80 @@ mod tests { SessionError::TooManyErrors )); } + + #[tokio::test] + async fn test_starttls() { + rustls::crypto::ring::default_provider() + .install_default() + .unwrap(); + + let (stream1, mut stream2) = duplex(128); + let rustls_server_cfg = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new( + StubResolver::new(TEST_CERT_1.as_bytes(), TEST_KEY_1.as_bytes()).unwrap(), + )); + + let mut cfg = SessionConfig::new("test"); + cfg.tls_mode = SessionTlsMode::Required(Arc::new(rustls_server_cfg)); + + let session_manager = SessionManager::new(); + tokio::spawn(async move { + session_manager + .handle_connection( + stream1, + SocketAddr::from_str("1.1.1.1:123").unwrap(), + Arc::new(cfg), + CancellationToken::new(), + ) + .await; + }); + + let mut buf = vec![0; 256]; + + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"220 test ESMTP IC SMTP Gateway\r\n"); + + // Make sure there's a 250-STARTTLS in EHLO + stream2.write_all(b"EHLO foo.bar\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-STARTTLS\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + + // Make sure TLS is required by server + stream2.write_all(b"MAIL FROM:\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!( + &buf[..r], + b"503 5.5.1 TLS is required to submit mail on this server.\r\n" + ); + + stream2.write_all(b"STARTTLS\r\n").await.unwrap(); + let r = stream2.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"220 2.0.0 Ready to start TLS.\r\n"); + + let rustls_client_cfg = ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoopServerCertVerifier::default())) + .with_no_client_auth(); + let tls_connector = TlsConnector::from(Arc::new(rustls_client_cfg)); + let mut tls_stream = tls_connector + .connect(ServerName::try_from("foo").unwrap(), stream2) + .await + .unwrap(); + + // Make sure there's NO 250-STARTTLS in EHLO anymore inside TLS session + tls_stream.write_all(b"EHLO foo.bar\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); + + // No TLS-in-TLS allowed + tls_stream.write_all(b"STARTTLS\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"504 5.7.4 Already in TLS mode.\r\n"); + + // Now MAIL FROM should work + tls_stream.write_all(b"MAIL FROM:\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"250 2.1.0 OK\r\n"); + } } diff --git a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs index 642dfa5..731a4d1 100644 --- a/ic-bn-lib/src/smtp/inbound/rcpt_to.rs +++ b/ic-bn-lib/src/smtp/inbound/rcpt_to.rs @@ -51,13 +51,14 @@ impl Session { { Ok(v) => match v { RecipientPolicy::Accept => { - self.data.rcpt_to.insert(address); + self.data.rcpt_to.push(address); } RecipientPolicy::Rewrite(new_address) => { - self.data.rcpt_to.insert(new_address); + self.data.rcpt_to.push(new_address); } - RecipientPolicy::Expand(new_addresses) => { - self.data.rcpt_to.extend(new_addresses); + RecipientPolicy::Expand(additional_addresses) => { + self.data.rcpt_to.push(address); + self.data.rcpt_to.extend(additional_addresses); } }, diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs index 05656a6..c331541 100644 --- a/ic-bn-lib/src/smtp/inbound/session.rs +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -13,6 +13,7 @@ use tokio::{ select, }; use tokio_util::{sync::CancellationToken, time::FutureExt}; +use tracing::debug; use uuid::Uuid; use crate::{ @@ -29,6 +30,7 @@ const MAX_REPLY_LEN: usize = 256; impl Session { /// Writes given bytes to the session & flushes the buffer pub async fn write(&mut self, bytes: &[u8]) -> SessionResult<()> { + debug!("{self}: Writing: {}", String::from_utf8_lossy(bytes)); self.stream.write_all(bytes).await?; self.stream.flush().await?; Ok(()) @@ -175,6 +177,8 @@ impl Session { } async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { + debug!("{self}: Read: {}", String::from_utf8_lossy(bytes)); + // Check if we are over session transfer quota if self.counters.bytes_ingested + bytes.len() >= self.cfg.max_session_data { self.reply("452", "4.7.28", "Session transfer quota exceeded.") @@ -256,6 +260,7 @@ impl Session { self.counters.errors += 1; } else { self.reply("220", "2.0.0", "Ready to start TLS.").await?; + self.state = state; return Ok(SessionUpgrade::StartTls); } } @@ -354,7 +359,10 @@ impl Session { res = self.stream.read(&mut buf).timeout(self.cfg.timeout) => { match res { Ok(Ok(bytes_read)) => { - self.ingest(&buf[..bytes_read]).await?; + let upgrade = self.ingest(&buf[..bytes_read]).await?; + if matches!(upgrade, SessionUpgrade::StartTls) { + return Ok(upgrade); + } } Ok(Err(e)) => { return Err(e.into()); @@ -389,7 +397,7 @@ impl Session { id, ehlo_hostname: self.data.ehlo_hostname.clone().unwrap(), mail_from: self.data.mail_from.take().unwrap(), - rcpt_to: self.data.rcpt_to.drain().collect(), + rcpt_to: self.data.rcpt_to.drain(..).collect(), body: std::mem::take(&mut self.data.message), }; diff --git a/ic-bn-lib/tools/smtp_server.rs b/ic-bn-lib/tools/smtp_server.rs index feddcd0..c39de8f 100644 --- a/ic-bn-lib/tools/smtp_server.rs +++ b/ic-bn-lib/tools/smtp_server.rs @@ -1,18 +1,32 @@ -use std::{net::SocketAddr, str::FromStr, time::Duration}; +use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration}; -use ic_bn_lib::smtp::{inbound::SessionConfig, server::Server}; +use ic_bn_lib::{ + smtp::{inbound::SessionConfig, server::Server}, + tests::{TEST_CERT_1, TEST_KEY_2}, + tls::resolver::StubResolver, +}; +use rustls::ServerConfig; use tokio_util::sync::CancellationToken; #[tokio::main] async fn main() { tracing_subscriber::fmt().init(); + rustls::crypto::ring::default_provider() + .install_default() + .unwrap(); + + let rustls_server_cfg = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new( + StubResolver::new(TEST_CERT_1.as_bytes(), TEST_KEY_2.as_bytes()).unwrap(), + )); let mut cfg = SessionConfig::new("mail.icp.net"); //cfg.helo_delay = Some(Duration::from_secs(1)); //params.max_message_size = 16; //params.max_session_duration = Duration::from_secs(30); //params.max_session_data = 16; - cfg.max_errors = 3; + cfg.tls_mode = ic_bn_lib::smtp::inbound::SessionTlsMode::Allowed(Arc::new(rustls_server_cfg)); let server = Server::new(SocketAddr::from_str("127.0.0.1:1025").unwrap(), cfg).unwrap(); From 682cad3433bba323cac8f02531c72327fd3bfbcd Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Mon, 18 May 2026 11:56:53 +0200 Subject: [PATCH 7/8] More tests, cleanups --- ic-bn-lib/src/smtp/ic/smtp.rs | 294 ------------------------ ic-bn-lib/src/smtp/inbound/ehlo.rs | 5 +- ic-bn-lib/src/smtp/inbound/mail_from.rs | 2 +- ic-bn-lib/src/smtp/inbound/manager.rs | 49 ++-- ic-bn-lib/src/smtp/inbound/mod.rs | 101 +++++--- ic-bn-lib/src/smtp/inbound/session.rs | 19 +- ic-bn-lib/src/smtp/server.rs | 27 +-- 7 files changed, 135 insertions(+), 362 deletions(-) delete mode 100644 ic-bn-lib/src/smtp/ic/smtp.rs diff --git a/ic-bn-lib/src/smtp/ic/smtp.rs b/ic-bn-lib/src/smtp/ic/smtp.rs deleted file mode 100644 index 678006c..0000000 --- a/ic-bn-lib/src/smtp/ic/smtp.rs +++ /dev/null @@ -1,294 +0,0 @@ -//! Candid types and submit logic for the SMTP gateway ↔ canister protocol. - -use candid::{CandidType, Decode, Deserialize, Encode, Principal}; -use ic_bn_lib::ic_agent::{Agent, AgentError}; -use tracing::{debug, warn}; - -use crate::smtp::ReceivedMail; - -/// Candid `Header`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct Header { - pub name: String, - pub value: String, -} - -/// Candid `Message`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct Message { - pub headers: Vec
, - pub body: Vec, -} - -/// Candid `Address`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct Address { - pub user: String, - pub domain: String, -} - -/// Candid `Envelope`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct Envelope { - pub from: Address, - pub to: Address, -} - -/// Candid `SmtpRequest`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct SmtpRequest { - pub message: Option, - pub envelope: Option, - pub gateway_flags: Option>, -} - -/// Candid `SmtpRequestError` (`code` is `nat64` on the wire in typical canisters). -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct SmtpRequestError { - pub code: u64, - pub message: String, -} - -/// Candid `SmtpResponse` — `Ok` carries an empty record. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub enum SmtpResponse { - Ok(SmtpOk), - Err(SmtpRequestError), -} - -/// Empty record for variant `Ok`. -#[derive(Clone, Debug, CandidType, Deserialize)] -pub struct SmtpOk {} - -fn address_from_smtp_path(path: &str) -> Address { - let path = path.trim(); - if path.is_empty() { - return Address { - user: String::new(), - domain: String::new(), - }; - } - match path.rsplit_once('@') { - Some((user, domain)) => Address { - user: user.to_string(), - domain: domain.to_string(), - }, - None => Address { - user: path.to_string(), - domain: String::new(), - }, - } -} - -/// Split RFC 5322 message into header block and body; parse headers with line unfolding. -pub fn parse_rfc5322_message(raw: &[u8]) -> Result { - let (header_end, body_start) = raw - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map(|i| (i, i + 4)) - .or_else(|| { - raw.windows(2) - .position(|w| w == b"\n\n") - .map(|i| (i, i + 2)) - }) - .ok_or_else(|| "message has no header/body separator".to_string())?; - - let headers_src = std::str::from_utf8(&raw[..header_end]) - .map_err(|_| "message headers are not valid UTF-8".to_string())?; - let headers = parse_headers_unfolded(headers_src)?; - let body = raw[body_start..].to_vec(); - Ok(Message { headers, body }) -} - -fn parse_headers_unfolded(block: &str) -> Result, String> { - let lines = unfold_header_block(block); - let mut out = Vec::new(); - for line in lines { - let line = line.trim_end_matches(['\r', '\n']); - if line.is_empty() { - continue; - } - let Some((name, value)) = line.split_once(':') else { - return Err(format!("bad header line: {line:?}")); - }; - let name = name.trim().to_string(); - if name.is_empty() { - return Err("empty header name".to_string()); - } - let value = value.trim_start_matches([' ', '\t']).to_string(); - out.push(Header { name, value }); - } - Ok(out) -} - -/// RFC 5322 unfolding: continuation lines start with WSP. -fn unfold_header_block(block: &str) -> Vec { - let mut merged: Vec = Vec::new(); - for raw_line in block.split('\n') { - let line = raw_line.trim_end_matches('\r'); - let first = line.chars().next(); - let is_continuation = matches!(first, Some(' ' | '\t')); - if is_continuation && !merged.is_empty() { - let last = merged.last_mut().expect("merged non-empty"); - last.push(' '); - last.push_str(line.trim_start_matches([' ', '\t'])); - } else { - merged.push(line.to_string()); - } - } - merged -} - -/// Map canister SMTP-style error codes to an SMTP text reply (code + message for the client). -pub fn smtp_line_from_canister_err(e: &SmtpRequestError) -> (u16, String) { - let c = e.code; - let code = if (400..600).contains(&c) { - c as u16 - } else if c < 400 { - 451 - } else { - 554 - }; - (code, e.message.clone()) -} - -fn agent_err_to_string(e: AgentError) -> String { - e.to_string() -} - -/// Failure from [`submit_mail`]: transport/parse errors or a canister rejection for one recipient. -#[derive(Debug)] -pub enum SubmitMailError { - Other(String), - Rejected { - code: u16, - message: String, - failed_recipient: String, - }, -} - -impl SubmitMailError { - /// SMTP session reply text: `" "` for rejections (matches [`crate::smtp::session::handler_error_to_response`]). - pub fn into_handler_error(self) -> String { - match self { - SubmitMailError::Other(s) => s, - SubmitMailError::Rejected { code, message, .. } => format!("{code} {message}"), - } - } -} - -impl From for SubmitMailError { - fn from(s: String) -> Self { - SubmitMailError::Other(s) - } -} - -/// Submit mail: optional `smtp_request_validate` (query) per recipient, then `smtp_request` (update). -pub async fn submit_mail( - agent: &Agent, - canister_id: Principal, - mail: &ReceivedMail, - gateway_flags: &[String], - validate_before_update: bool, -) -> Result<(), SubmitMailError> { - if mail.rcpt_to.is_empty() { - return Err(SubmitMailError::Other( - "internal error: no recipients".to_string(), - )); - } - - let message = parse_rfc5322_message(&mail.raw_message).map_err(SubmitMailError::Other)?; - let from_addr = address_from_smtp_path(&mail.mail_from); - let flags = if gateway_flags.is_empty() { - None - } else { - Some(gateway_flags.to_vec()) - }; - - for to_path in &mail.rcpt_to { - let to_addr = address_from_smtp_path(to_path); - let envelope = Envelope { - from: from_addr.clone(), - to: to_addr, - }; - - if validate_before_update { - let validate_req = SmtpRequest { - message: None, - envelope: Some(envelope.clone()), - gateway_flags: flags.clone(), - }; - let arg = Encode!(&validate_req).map_err(|e| SubmitMailError::Other(e.to_string()))?; - let out = agent - .query(&canister_id, "smtp_request_validate") - .with_arg(arg) - .call() - .await - .map_err(|e| SubmitMailError::Other(agent_err_to_string(e)))?; - let resp = Decode!(&out, SmtpResponse).map_err(|e| SubmitMailError::Other(e.to_string()))?; - match resp { - SmtpResponse::Ok(_) => {} - SmtpResponse::Err(err) => { - let (code, msg) = smtp_line_from_canister_err(&err); - warn!(%canister_id, %to_path, canister_code = %err.code, "smtp_request_validate rejected"); - return Err(SubmitMailError::Rejected { - code, - message: msg, - failed_recipient: to_path.clone(), - }); - } - } - } - - let full = SmtpRequest { - message: Some(message.clone()), - envelope: Some(envelope), - gateway_flags: flags.clone(), - }; - let arg = Encode!(&full).map_err(|e| SubmitMailError::Other(e.to_string()))?; - let out = agent - .update(&canister_id, "smtp_request") - .with_arg(arg) - .call_and_wait() - .await - .map_err(|e| SubmitMailError::Other(agent_err_to_string(e)))?; - let resp = Decode!(&out, SmtpResponse).map_err(|e| SubmitMailError::Other(e.to_string()))?; - match resp { - SmtpResponse::Ok(_) => { - debug!(%canister_id, %to_path, "smtp_request accepted"); - } - SmtpResponse::Err(err) => { - let (code, msg) = smtp_line_from_canister_err(&err); - warn!(%canister_id, %to_path, canister_code = %err.code, "smtp_request rejected"); - return Err(SubmitMailError::Rejected { - code, - message: msg, - failed_recipient: to_path.clone(), - }); - } - } - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_simple_message() { - let raw = b"From: a@b\r\nTo: c@d\r\n\r\nhello"; - let m = parse_rfc5322_message(raw).unwrap(); - assert_eq!(m.headers.len(), 2); - assert_eq!(m.body, b"hello"); - } - - #[test] - fn unfold_continuation() { - let block = "Subject: very\r\n long\r\n line"; - let lines = unfold_header_block(block); - assert_eq!(lines.len(), 1); - assert!(lines[0].contains("very long line")); - } -} diff --git a/ic-bn-lib/src/smtp/inbound/ehlo.rs b/ic-bn-lib/src/smtp/inbound/ehlo.rs index 88612f1..cd4b7d8 100644 --- a/ic-bn-lib/src/smtp/inbound/ehlo.rs +++ b/ic-bn-lib/src/smtp/inbound/ehlo.rs @@ -20,7 +20,8 @@ impl Session { return self.reply("550", "5.5.0", "Invalid EHLO hostname.").await; }; - // If EHLO hostname is already set to the same value - just reply directly + // If EHLO hostname is already set to the same value - just reply directly, + // avoid redundant checks if let Some(v) = &self.data.ehlo_hostname && v == &ehlo_hostname { @@ -77,7 +78,7 @@ impl Session { response.size = self.cfg.max_message_size; // Send STARTTLS cap only if we support TLS & we're not already in TLS mode - if self.tls_info.is_none() && self.cfg.tls_enabled() { + if self.tls_info.is_none() && self.cfg.tls_mode.enabled() { response.capabilities |= EXT_START_TLS; } diff --git a/ic-bn-lib/src/smtp/inbound/mail_from.rs b/ic-bn-lib/src/smtp/inbound/mail_from.rs index 261d105..6a81463 100644 --- a/ic-bn-lib/src/smtp/inbound/mail_from.rs +++ b/ic-bn-lib/src/smtp/inbound/mail_from.rs @@ -30,7 +30,7 @@ impl Session { .await; } - if self.cfg.tls_required() && self.tls_info.is_none() { + if self.cfg.tls_mode.required() && self.tls_info.is_none() { return self .reply( "503", diff --git a/ic-bn-lib/src/smtp/inbound/manager.rs b/ic-bn-lib/src/smtp/inbound/manager.rs index 5a1c0e2..82ee965 100644 --- a/ic-bn-lib/src/smtp/inbound/manager.rs +++ b/ic-bn-lib/src/smtp/inbound/manager.rs @@ -1,6 +1,5 @@ use std::{net::SocketAddr, sync::Arc}; -use derive_new::new; use tokio::io::AsyncWriteExt; use tokio_rustls::server::TlsStream; use tokio_util::sync::CancellationToken; @@ -9,20 +8,19 @@ use tracing::{debug, info}; use crate::{ network::{AsyncReadWrite, tls_handshake}, smtp::inbound::{ - Session, SessionConfig, SessionData, SessionResult, SessionTlsMode, SessionUpgrade, + Session, SessionConfig, SessionData, SessionError, SessionResult, SessionTlsMode, + SessionUpgrade, }, }; /// Manages the lifetime of a single SMTP session. /// -/// Needed because the SMTP session can transition into TLS state +/// It's needed because the SMTP session can transition into TLS state /// which requires external orchestration. -#[derive(new)] pub struct SessionManager; impl SessionManager { pub async fn handle_connection( - &self, stream: S, remote_addr: SocketAddr, params: Arc, @@ -37,29 +35,44 @@ impl SessionManager { } SessionUpgrade::StartTls => { - let log_name = session.to_string(); - match session.into_tls().await { - Ok(mut session) => { - if let Err(e) = session.handle(shutdown_token.child_token()).await { - info!("{session}: error: {e:#}"); - session.stream.shutdown().await.ok(); - } - } - Err(e) => { - info!("{log_name}: TLS handshake failed: {e:#}"); - } - }; + Self::starttls(session, shutdown_token.child_token()).await } }, Err(e) => { - info!("{session}: error: {e:#}, closing connection"); + if !matches!(e, SessionError::Quit) { + info!("{session}: error: {e:#}"); + } + if let Err(e) = session.shutdown().await { debug!("{session}: error closing connection: {e:#}"); }; } } } + + /// Converts session into TLS mode + async fn starttls(session: Session, shutdown_token: CancellationToken) { + let session_name = session.to_string(); + + match session.into_tls().await { + Ok(mut session) => { + if let Err(e) = session.handle(shutdown_token.child_token()).await { + if !matches!(e, SessionError::Quit) { + info!("{session}: error: {e:#}"); + } + + if let Err(e) = session.shutdown().await { + debug!("{session}: error closing connection: {e:#}"); + }; + } + } + + Err(e) => { + info!("{session_name}: TLS handshake failed: {e:#}"); + } + }; + } } impl Session { diff --git a/ic-bn-lib/src/smtp/inbound/mod.rs b/ic-bn-lib/src/smtp/inbound/mod.rs index de6b4f9..20c5e59 100644 --- a/ic-bn-lib/src/smtp/inbound/mod.rs +++ b/ic-bn-lib/src/smtp/inbound/mod.rs @@ -74,6 +74,16 @@ pub enum SessionTlsMode { Required(Arc), } +impl SessionTlsMode { + pub const fn enabled(&self) -> bool { + matches!(self, Self::Allowed(_) | Self::Required(_)) + } + + pub const fn required(&self) -> bool { + matches!(self, Self::Required(_)) + } +} + /// SMTP session config pub struct SessionConfig { hostname: String, @@ -126,17 +136,6 @@ impl SessionConfig { delivery_agent: Arc::new(DummyDeliveryAgent), } } - - pub const fn tls_enabled(&self) -> bool { - matches!( - self.tls_mode, - SessionTlsMode::Allowed(_) | SessionTlsMode::Required(_) - ) - } - - pub const fn tls_required(&self) -> bool { - matches!(self.tls_mode, SessionTlsMode::Required(_)) - } } /// SMTP session state @@ -207,7 +206,12 @@ pub struct Session { impl Display for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SMTP/Session({})", self.remote_ip) + write!( + f, + "SMTP/Session({}){}", + self.remote_ip, + if self.tls_info.is_some() { "/TLS" } else { "" } + ) } } @@ -298,7 +302,7 @@ mod tests { cfg.max_message_size = 512; cfg.helo_delay = helo_delay; cfg.max_messages_per_session = 3; - cfg.max_session_data = 1024; + cfg.max_session_data = 8192; cfg.max_recipients = 3; Session::new(IpAddr::from_str("1.1.1.1").unwrap(), stream, Arc::new(cfg)) @@ -407,7 +411,9 @@ mod tests { .read(b"56789") .write(b"250 2.6.0 Chunk accepted.\r\n") .read(b"BDAT 10\r\n") - .read(b"9876543210") + .read(b"987") + .read(b"654") + .read(b"3210") .write(b"250 2.6.0 Chunk accepted.\r\n") .read(b"BDAT 10 LAST\r\n") .read(b"0123456789") @@ -616,6 +622,8 @@ mod tests { .read(b"FOO\r\n") .write(b"500 5.5.1 Invalid command.\r\n") .read(b"FOO\r\n") + .write(b"500 5.5.1 Invalid command.\r\n") + .read(b"FOO\r\n") .write(b"452 4.3.2 Too many errors.\r\n") .build(); @@ -627,13 +635,50 @@ mod tests { )); } + #[tokio::test] + async fn test_request_too_large() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(format!("EHLO {}", "1".repeat(2048)).as_bytes()) + .read(format!("{}\r\n", "1".repeat(2048)).as_bytes()) + .write(b"554 5.3.4 Line is too long.\r\n") + .read(b"QUIT\r\n") + .write(b"221 2.0.0 Bye.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::Quit + )); + } + + #[tokio::test] + async fn test_max_session_transfer_quota() { + let stream = tokio_test::io::Builder::new() + .write(b"220 test ESMTP IC SMTP Gateway\r\n") + .read(format!("EHLO {}\r\n", "1".repeat(8192)).as_bytes()) + .write(b"452 4.7.28 Session transfer quota exceeded.\r\n") + .build(); + + let mut session = create_session(stream, None); + + assert!(matches!( + session.handle(CancellationToken::new()).await.unwrap_err(), + SessionError::TransferQuotaExceeded(_) + )); + } + #[tokio::test] async fn test_starttls() { rustls::crypto::ring::default_provider() .install_default() - .unwrap(); + .ok(); + // Use an in-memory pipe let (stream1, mut stream2) = duplex(128); + let rustls_server_cfg = ServerConfig::builder() .with_no_client_auth() .with_cert_resolver(Arc::new( @@ -643,16 +688,14 @@ mod tests { let mut cfg = SessionConfig::new("test"); cfg.tls_mode = SessionTlsMode::Required(Arc::new(rustls_server_cfg)); - let session_manager = SessionManager::new(); tokio::spawn(async move { - session_manager - .handle_connection( - stream1, - SocketAddr::from_str("1.1.1.1:123").unwrap(), - Arc::new(cfg), - CancellationToken::new(), - ) - .await; + SessionManager::handle_connection( + stream1, + SocketAddr::from_str("1.1.1.1:123").unwrap(), + Arc::new(cfg), + CancellationToken::new(), + ) + .await; }); let mut buf = vec![0; 256]; @@ -665,7 +708,7 @@ mod tests { let r = stream2.read(&mut buf).await.unwrap(); assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-STARTTLS\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); - // Make sure TLS is required by server + // Make sure TLS is required by the server due to SessionTlsMode::Required stream2.write_all(b"MAIL FROM:\r\n").await.unwrap(); let r = stream2.read(&mut buf).await.unwrap(); assert_eq!( @@ -673,6 +716,7 @@ mod tests { b"503 5.5.1 TLS is required to submit mail on this server.\r\n" ); + // Fire up TLS handshake stream2.write_all(b"STARTTLS\r\n").await.unwrap(); let r = stream2.read(&mut buf).await.unwrap(); assert_eq!(&buf[..r], b"220 2.0.0 Ready to start TLS.\r\n"); @@ -687,7 +731,7 @@ mod tests { .await .unwrap(); - // Make sure there's NO 250-STARTTLS in EHLO anymore inside TLS session + // Make sure there's no 250-STARTTLS in EHLO anymore inside TLS session tls_stream.write_all(b"EHLO foo.bar\r\n").await.unwrap(); let r = tls_stream.read(&mut buf).await.unwrap(); assert_eq!(&buf[..r], b"250-test you had me at EHLO\r\n250-SMTPUTF8\r\n250-ENHANCEDSTATUSCODES\r\n250-CHUNKING\r\n250 8BITMIME\r\n"); @@ -701,5 +745,10 @@ mod tests { tls_stream.write_all(b"MAIL FROM:\r\n").await.unwrap(); let r = tls_stream.read(&mut buf).await.unwrap(); assert_eq!(&buf[..r], b"250 2.1.0 OK\r\n"); + + // All good + tls_stream.write_all(b"QUIT\r\n").await.unwrap(); + let r = tls_stream.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..r], b"221 2.0.0 Bye.\r\n"); } } diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs index c331541..e3bec48 100644 --- a/ic-bn-lib/src/smtp/inbound/session.rs +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -176,11 +176,12 @@ impl Session { Ok(()) } + /// Main SMTP state machine async fn ingest(&mut self, bytes: &[u8]) -> SessionResult { debug!("{self}: Read: {}", String::from_utf8_lossy(bytes)); // Check if we are over session transfer quota - if self.counters.bytes_ingested + bytes.len() >= self.cfg.max_session_data { + if self.counters.bytes_ingested + bytes.len() > self.cfg.max_session_data { self.reply("452", "4.7.28", "Session transfer quota exceeded.") .await?; return Err(SessionError::TransferQuotaExceeded( @@ -198,7 +199,7 @@ impl Session { } // Check if we are over error limit - if self.counters.errors >= self.cfg.max_errors { + if self.counters.errors > self.cfg.max_errors { self.reply("452", "4.3.2", "Too many errors.").await?; return Err(SessionError::TooManyErrors); } @@ -241,7 +242,7 @@ impl Session { chunk_size, )) } else { - // Allocate the needed capacity for the chunk + // Preallocate the needed capacity for the chunk if need be let free = self.data.message.capacity() - self.data.message.len(); if free < chunk_size { @@ -255,7 +256,7 @@ impl Session { if self.tls_info.is_some() { self.reply("504", "5.7.4", "Already in TLS mode.").await?; self.counters.errors += 1; - } else if !self.cfg.tls_enabled() { + } else if !self.cfg.tls_mode.enabled() { self.reply("502", "5.7.0", "TLS not available.").await?; self.counters.errors += 1; } else { @@ -268,13 +269,15 @@ impl Session { self.handle_request(other_request).await?; } }, - // In case of NeedsMoreData error we just leave - // and wait for new data to be ingested - Err(SmtpError::NeedsMoreData { .. }) => break, Err(SmtpError::ResponseTooLong) => { state = SessionState::RequestTooLarge(DummyLineReceiver::default()); continue; } + // In case of NeedsMoreData error we just leave + // and wait for new data to be ingested + Err(SmtpError::NeedsMoreData { .. }) => break, + + // Handle other errors separately Err(e) => { self.handle_error(e).await?; } @@ -397,7 +400,7 @@ impl Session { id, ehlo_hostname: self.data.ehlo_hostname.clone().unwrap(), mail_from: self.data.mail_from.take().unwrap(), - rcpt_to: self.data.rcpt_to.drain(..).collect(), + rcpt_to: std::mem::take(&mut self.data.rcpt_to), body: std::mem::take(&mut self.data.message), }; diff --git a/ic-bn-lib/src/smtp/server.rs b/ic-bn-lib/src/smtp/server.rs index 4fa2bfc..7e45316 100644 --- a/ic-bn-lib/src/smtp/server.rs +++ b/ic-bn-lib/src/smtp/server.rs @@ -29,15 +29,14 @@ impl Display for Server { } impl Server { - pub fn new(listen_addr: SocketAddr, cfg: SessionConfig) -> Result { + /// Creates a new `Server` to listen on `listen_addr` + pub fn new(listen_addr: SocketAddr, cfg: SessionConfig) -> io::Result { let listener = listen_tcp(listen_addr, ListenerOpts::default())?; Self::new_with_listener(listener, cfg) } - pub fn new_with_listener( - listener: TcpListener, - params: SessionConfig, - ) -> Result { + /// Creates a new `Server` from a pre-built `TcpListener` + pub fn new_with_listener(listener: TcpListener, params: SessionConfig) -> io::Result { Ok(Self { listen_addr: listener.local_addr()?, listener, @@ -46,7 +45,7 @@ impl Server { }) } - async fn handle_accept( + async fn handle_connection( &self, res: io::Result<(TcpStream, SocketAddr)>, token: &CancellationToken, @@ -55,11 +54,9 @@ impl Server { Ok((stream, addr)) => { info!("{self}: New connection from {addr}"); - let (manager, params, token) = - (SessionManager, self.params.clone(), token.child_token()); - + let (params, token) = (self.params.clone(), token.child_token()); self.tracker.spawn(async move { - manager.handle_connection(stream, addr, params, token).await; + SessionManager::handle_connection(stream, addr, params, token).await; }); } @@ -70,18 +67,22 @@ impl Server { } } - pub async fn serve(&self, token: CancellationToken) -> SessionResult<()> { + /// Main connection handling loop + pub async fn serve(&self, token: CancellationToken) -> io::Result<()> { loop { select! { res = self.listener.accept() => { - self.handle_accept(res, &token).await; + self.handle_connection(res, &token).await; } () = token.cancelled() => { + warn!("{self}: Server shutting down, closing connections"); + self.tracker.close(); - if self.tracker.wait().timeout(Duration::from_secs(60)).await.is_err() { + if self.tracker.wait().timeout(Duration::from_secs(30)).await.is_err() { warn!("{self}: Timed out waiting for connections to close"); } + break; } } From 80c8405eecf52ac9c4748698a112e214ada069b2 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Mon, 18 May 2026 17:19:41 +0200 Subject: [PATCH 8/8] Address comments --- ic-bn-lib/src/smtp/inbound/manager.rs | 4 +++- ic-bn-lib/src/smtp/inbound/session.rs | 18 +++--------------- ic-bn-lib/src/smtp/server.rs | 6 +++--- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/ic-bn-lib/src/smtp/inbound/manager.rs b/ic-bn-lib/src/smtp/inbound/manager.rs index 82ee965..d8f8a8d 100644 --- a/ic-bn-lib/src/smtp/inbound/manager.rs +++ b/ic-bn-lib/src/smtp/inbound/manager.rs @@ -82,7 +82,9 @@ impl Session { // It's better to panic otherwise. let tls_config = match &self.cfg.tls_mode { SessionTlsMode::Allowed(v) | SessionTlsMode::Required(v) => v.clone(), - SessionTlsMode::Disabled => unreachable!(), + SessionTlsMode::Disabled => { + unreachable!("Session::into_tls() called with TLS disabled") + } }; let (stream, tls_info) = tls_handshake(tls_config, self.stream).await?; diff --git a/ic-bn-lib/src/smtp/inbound/session.rs b/ic-bn-lib/src/smtp/inbound/session.rs index e3bec48..55ae28f 100644 --- a/ic-bn-lib/src/smtp/inbound/session.rs +++ b/ic-bn-lib/src/smtp/inbound/session.rs @@ -1,5 +1,6 @@ use std::{ borrow::Cow, + io::Write, time::{Duration, Instant}, }; @@ -48,21 +49,8 @@ impl Session { "Reply longer than supported - increase MAX_REPLY_LEN" ); - // Poor man's `format!` let mut buf = [0; MAX_REPLY_LEN]; - let (mut i, mut j) = (0, code.len()); - buf[i..j].copy_from_slice(code.as_bytes()); - buf[j] = b' '; - i += code.len() + 1; - j += ext.len() + 1; - buf[i..j].copy_from_slice(ext.as_bytes()); - buf[j] = b' '; - i += ext.len() + 1; - j += msg.len() + 1; - buf[i..j].copy_from_slice(msg.as_bytes()); - buf[j] = b'\r'; - buf[j + 1] = b'\n'; - + write!(&mut buf[..], "{code} {ext} {msg}\r\n")?; self.write(&buf[..len]).await } @@ -90,7 +78,7 @@ impl Session { // If something comes in - then the client isn't respecting the protocol, // we consider him malicious and drop the connection. if let Some(v) = self.cfg.helo_delay { - let mut buf = vec![0; 128]; + let mut buf = [0; 256]; match self.stream.read(&mut buf).timeout(v).await { Ok(Ok(bytes_read)) => { if bytes_read > 0 { diff --git a/ic-bn-lib/src/smtp/server.rs b/ic-bn-lib/src/smtp/server.rs index 7e45316..0e6d627 100644 --- a/ic-bn-lib/src/smtp/server.rs +++ b/ic-bn-lib/src/smtp/server.rs @@ -55,9 +55,9 @@ impl Server { info!("{self}: New connection from {addr}"); let (params, token) = (self.params.clone(), token.child_token()); - self.tracker.spawn(async move { - SessionManager::handle_connection(stream, addr, params, token).await; - }); + self.tracker.spawn(SessionManager::handle_connection( + stream, addr, params, token, + )); } Err(e) => {