diff --git a/crates/test-programs/src/bin/cli_no_tcp.rs b/crates/test-programs/src/bin/cli_no_tcp.rs index 58ca3de8958d..a59a0559c035 100644 --- a/crates/test-programs/src/bin/cli_no_tcp.rs +++ b/crates/test-programs/src/bin/cli_no_tcp.rs @@ -1,28 +1,12 @@ //! This test assumes that it will be run without tcp support enabled -use test_programs::wasi::sockets::{ - network::IpAddress, - tcp::{ErrorCode, IpAddressFamily, IpSocketAddress, Network, TcpSocket}, -}; -fn main() { - let net = Network::default(); - let family = IpAddressFamily::Ipv4; - let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); - let sock = TcpSocket::new(family).unwrap(); - - let bind = sock.blocking_bind(&net, remote1); - eprintln!("Result of binding: {bind:?}"); - assert!(matches!(bind, Err(ErrorCode::AccessDenied))); +#![deny(warnings)] - let listen = sock.blocking_listen(); - eprintln!("Result of listen: {listen:?}"); - assert!(matches!(listen, Err(ErrorCode::AccessDenied))); +use test_programs::wasi::sockets::tcp::{ErrorCode, IpAddressFamily, TcpSocket}; - let connect = sock.blocking_connect(&net, remote1); - eprintln!("Result of connect: {connect:?}"); - assert!(matches!(connect, Err(ErrorCode::AccessDenied))); - - let accept = sock.blocking_accept(); - eprintln!("Result of accept: {accept:?}"); - assert!(matches!(accept, Err(ErrorCode::AccessDenied))); +fn main() { + assert!(matches!( + TcpSocket::new(IpAddressFamily::Ipv4), + Err(ErrorCode::AccessDenied) + )); } diff --git a/crates/wasi/src/p2/bindings.rs b/crates/wasi/src/p2/bindings.rs index 7fb5e7b136da..be56c36f4ec4 100644 --- a/crates/wasi/src/p2/bindings.rs +++ b/crates/wasi/src/p2/bindings.rs @@ -393,7 +393,7 @@ mod async_io { // Configure all other resources to be concrete types defined in // this crate "wasi:sockets/network/network": crate::p2::network::Network, - "wasi:sockets/tcp/tcp-socket": crate::p2::tcp::TcpSocket, + "wasi:sockets/tcp/tcp-socket": crate::sockets::TcpSocket, 
"wasi:sockets/udp/udp-socket": crate::sockets::UdpSocket, "wasi:sockets/udp/incoming-datagram-stream": crate::p2::udp::IncomingDatagramStream, "wasi:sockets/udp/outgoing-datagram-stream": crate::p2::udp::OutgoingDatagramStream, diff --git a/crates/wasi/src/p2/host/tcp.rs b/crates/wasi/src/p2/host/tcp.rs index feb22c72b351..f806db07def9 100644 --- a/crates/wasi/src/p2/host/tcp.rs +++ b/crates/wasi/src/p2/host/tcp.rs @@ -1,11 +1,10 @@ -use crate::p2::SocketResult; use crate::p2::bindings::{ - sockets::network::{IpAddressFamily, IpSocketAddress, Network}, + sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}, sockets::tcp::{self, ShutdownType}, }; -use crate::sockets::{SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView}; +use crate::p2::{Pollable, SocketResult}; +use crate::sockets::{SocketAddrUse, TcpSocket, WasiSocketsCtxView}; use std::net::SocketAddr; -use std::time::Duration; use wasmtime::component::Resource; use wasmtime_wasi_io::{ poll::DynPollable, @@ -17,11 +16,10 @@ impl tcp::Host for WasiSocketsCtxView<'_> {} impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { async fn start_bind( &mut self, - this: Resource, + this: Resource, network: Resource, local_address: IpSocketAddress, ) -> SocketResult<()> { - self.ctx.allowed_network_uses.check_allowed_tcp()?; let network = self.table.get(&network)?; let local_address: SocketAddr = local_address.into(); @@ -36,19 +34,18 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { Ok(()) } - fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { + fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - - socket.finish_bind() + socket.finish_bind()?; + Ok(()) } async fn start_connect( &mut self, - this: Resource, + this: Resource, network: Resource, remote_address: IpSocketAddress, ) -> SocketResult<()> { - self.ctx.allowed_network_uses.check_allowed_tcp()?; let network = 
self.table.get(&network)?; let remote_address: SocketAddr = remote_address.into(); @@ -58,49 +55,56 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { .await?; // Start connection - self.table.get_mut(&this)?.start_connect(remote_address)?; + let socket = self.table.get_mut(&this)?; + let future = socket + .start_connect(&remote_address)? + .connect(remote_address); + socket.set_pending_connect(future)?; Ok(()) } fn finish_connect( &mut self, - this: Resource, + this: Resource, ) -> SocketResult<(Resource, Resource)> { let socket = self.table.get_mut(&this)?; - let (input, output) = socket.finish_connect()?; - - let input_stream = self.table.push_child(input, &this)?; - let output_stream = self.table.push_child(output, &this)?; - - Ok((input_stream, output_stream)) + let result = socket + .take_pending_connect()? + .ok_or(ErrorCode::WouldBlock)?; + socket.finish_connect(result)?; + let (input, output) = socket.p2_streams()?; + let input = self.table.push_child(input, &this)?; + let output = self.table.push_child(output, &this)?; + Ok((input, output)) } - fn start_listen(&mut self, this: Resource) -> SocketResult<()> { - self.ctx.allowed_network_uses.check_allowed_tcp()?; + fn start_listen(&mut self, this: Resource) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.start_listen() + socket.start_listen()?; + Ok(()) } - fn finish_listen(&mut self, this: Resource) -> SocketResult<()> { + fn finish_listen(&mut self, this: Resource) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.finish_listen() + socket.finish_listen()?; + Ok(()) } fn accept( &mut self, - this: Resource, + this: Resource, ) -> SocketResult<( - Resource, + Resource, Resource, Resource, )> { - self.ctx.allowed_network_uses.check_allowed_tcp()?; let socket = self.table.get_mut(&this)?; - let (tcp_socket, input, output) = socket.accept()?; + let mut tcp_socket = socket.accept()?.ok_or(ErrorCode::WouldBlock)?; + let (input, output) = 
tcp_socket.p2_streams()?; let tcp_socket = self.table.push(tcp_socket)?; let input_stream = self.table.push_child(input, &tcp_socket)?; @@ -109,19 +113,17 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { Ok((tcp_socket, input_stream, output_stream)) } - fn local_address(&mut self, this: Resource) -> SocketResult { + fn local_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - socket.local_address().map(Into::into) + Ok(socket.local_address()?.into()) } - fn remote_address(&mut self, this: Resource) -> SocketResult { + fn remote_address(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - - socket.remote_address().map(Into::into) + Ok(socket.remote_address()?.into()) } - fn is_listening(&mut self, this: Resource) -> Result { + fn is_listening(&mut self, this: Resource) -> Result { let socket = self.table.get(&this)?; Ok(socket.is_listening()) @@ -129,135 +131,122 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { fn address_family( &mut self, - this: Resource, + this: Resource, ) -> Result { let socket = self.table.get(&this)?; - - match socket.address_family() { - SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), - SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), - } + Ok(socket.address_family().into()) } fn set_listen_backlog_size( &mut self, - this: Resource, + this: Resource, value: u64, ) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - - // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. 
- let value = value.try_into().unwrap_or(u32::MAX); - - socket.set_listen_backlog_size(value) + socket.set_listen_backlog_size(value)?; + Ok(()) } - fn keep_alive_enabled(&mut self, this: Resource) -> SocketResult { + fn keep_alive_enabled(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - socket.keep_alive_enabled() + Ok(socket.keep_alive_enabled()?) } fn set_keep_alive_enabled( &mut self, - this: Resource, + this: Resource, value: bool, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - socket.set_keep_alive_enabled(value) + socket.set_keep_alive_enabled(value)?; + Ok(()) } - fn keep_alive_idle_time(&mut self, this: Resource) -> SocketResult { + fn keep_alive_idle_time(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - Ok(socket.keep_alive_idle_time()?.as_nanos() as u64) + Ok(socket.keep_alive_idle_time()?) } fn set_keep_alive_idle_time( &mut self, - this: Resource, + this: Resource, value: u64, ) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.set_keep_alive_idle_time(value) + socket.set_keep_alive_idle_time(value)?; + Ok(()) } - fn keep_alive_interval(&mut self, this: Resource) -> SocketResult { + fn keep_alive_interval(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - Ok(socket.keep_alive_interval()?.as_nanos() as u64) + Ok(socket.keep_alive_interval()?) } fn set_keep_alive_interval( &mut self, - this: Resource, + this: Resource, value: u64, ) -> SocketResult<()> { let socket = self.table.get(&this)?; - socket.set_keep_alive_interval(Duration::from_nanos(value)) + socket.set_keep_alive_interval(value)?; + Ok(()) } - fn keep_alive_count(&mut self, this: Resource) -> SocketResult { + fn keep_alive_count(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - socket.keep_alive_count() + Ok(socket.keep_alive_count()?) 
} - fn set_keep_alive_count( - &mut self, - this: Resource, - value: u32, - ) -> SocketResult<()> { + fn set_keep_alive_count(&mut self, this: Resource, value: u32) -> SocketResult<()> { let socket = self.table.get(&this)?; - socket.set_keep_alive_count(value) + socket.set_keep_alive_count(value)?; + Ok(()) } - fn hop_limit(&mut self, this: Resource) -> SocketResult { + fn hop_limit(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - socket.hop_limit() + Ok(socket.hop_limit()?) } - fn set_hop_limit(&mut self, this: Resource, value: u8) -> SocketResult<()> { + fn set_hop_limit(&mut self, this: Resource, value: u8) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.set_hop_limit(value) + socket.set_hop_limit(value)?; + Ok(()) } - fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { + fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - Ok(socket.receive_buffer_size()?) } fn set_receive_buffer_size( &mut self, - this: Resource, + this: Resource, value: u64, ) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.set_receive_buffer_size(value) + socket.set_receive_buffer_size(value)?; + Ok(()) } - fn send_buffer_size(&mut self, this: Resource) -> SocketResult { + fn send_buffer_size(&mut self, this: Resource) -> SocketResult { let socket = self.table.get(&this)?; - Ok(socket.send_buffer_size()?) 
} - fn set_send_buffer_size( - &mut self, - this: Resource, - value: u64, - ) -> SocketResult<()> { + fn set_send_buffer_size(&mut self, this: Resource, value: u64) -> SocketResult<()> { let socket = self.table.get_mut(&this)?; - socket.set_send_buffer_size(value) + socket.set_send_buffer_size(value)?; + Ok(()) } - fn subscribe( - &mut self, - this: Resource, - ) -> anyhow::Result> { + fn subscribe(&mut self, this: Resource) -> anyhow::Result> { wasmtime_wasi_io::poll::subscribe(self.table, this) } fn shutdown( &mut self, - this: Resource, + this: Resource, shutdown_type: ShutdownType, ) -> SocketResult<()> { let socket = self.table.get(&this)?; @@ -267,10 +256,13 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { ShutdownType::Send => std::net::Shutdown::Write, ShutdownType::Both => std::net::Shutdown::Both, }; - socket.shutdown(how) + + let state = socket.p2_streaming_state()?; + state.shutdown(how)?; + Ok(()) } - fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { + fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { // As in the filesystem implementation, we assume closing a socket // doesn't block. 
let dropped = self.table.delete(this)?; @@ -280,6 +272,13 @@ impl crate::p2::host::tcp::tcp::HostTcpSocket for WasiSocketsCtxView<'_> { } } +#[async_trait::async_trait] +impl Pollable for TcpSocket { + async fn ready(&mut self) { + ::ready(self).await; + } +} + pub mod sync { use crate::p2::{ SocketError, diff --git a/crates/wasi/src/p2/host/tcp_create_socket.rs b/crates/wasi/src/p2/host/tcp_create_socket.rs index 211648e0ac00..10fcc75edc2b 100644 --- a/crates/wasi/src/p2/host/tcp_create_socket.rs +++ b/crates/wasi/src/p2/host/tcp_create_socket.rs @@ -1,7 +1,6 @@ use crate::p2::SocketResult; use crate::p2::bindings::{sockets::network::IpAddressFamily, sockets::tcp_create_socket}; -use crate::p2::tcp::TcpSocket; -use crate::sockets::WasiSocketsCtxView; +use crate::sockets::{SocketAddressFamily, TcpSocket, WasiSocketsCtxView}; use wasmtime::component::Resource; impl tcp_create_socket::Host for WasiSocketsCtxView<'_> { @@ -9,8 +8,17 @@ impl tcp_create_socket::Host for WasiSocketsCtxView<'_> { &mut self, address_family: IpAddressFamily, ) -> SocketResult> { - let socket = TcpSocket::new(address_family.into())?; + let socket = TcpSocket::new(self.ctx, address_family.into())?; let socket = self.table.push(socket)?; Ok(socket) } } + +impl From for SocketAddressFamily { + fn from(family: IpAddressFamily) -> SocketAddressFamily { + match family { + IpAddressFamily::Ipv4 => Self::Ipv4, + IpAddressFamily::Ipv6 => Self::Ipv6, + } + } +} diff --git a/crates/wasi/src/p2/mod.rs b/crates/wasi/src/p2/mod.rs index ea2a9f48f9d6..857cba2d6eb8 100644 --- a/crates/wasi/src/p2/mod.rs +++ b/crates/wasi/src/p2/mod.rs @@ -243,6 +243,7 @@ mod write_stream; pub use self::filesystem::{FsError, FsResult}; pub use self::network::{Network, SocketError, SocketResult}; pub use self::stdio::IsATTY; +pub(crate) use tcp::P2TcpStreamingState; // These contents of wasmtime-wasi-io are re-exported by this module for compatibility: // they were originally defined in this module before being factored out, 
and many // users of this module depend on them at these names. diff --git a/crates/wasi/src/p2/network.rs b/crates/wasi/src/p2/network.rs index 8534d0320485..f1531c32a319 100644 --- a/crates/wasi/src/p2/network.rs +++ b/crates/wasi/src/p2/network.rs @@ -49,6 +49,7 @@ impl From for ErrorCode { crate::sockets::util::ErrorCode::ConnectionAborted => Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, crate::sockets::util::ErrorCode::NotInProgress => Self::NotInProgress, + crate::sockets::util::ErrorCode::ConcurrencyConflict => Self::ConcurrencyConflict, } } } diff --git a/crates/wasi/src/p2/tcp.rs b/crates/wasi/src/p2/tcp.rs index 1f7dc1b3f391..019c3c6b1f00 100644 --- a/crates/wasi/src/p2/tcp.rs +++ b/crates/wasi/src/p2/tcp.rs @@ -1,666 +1,54 @@ -use crate::p2::bindings::sockets::tcp::ErrorCode; use crate::p2::{ DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError, SocketResult, StreamError, }; -use crate::runtime::{AbortOnDropJoinHandle, with_ambient_tokio_runtime}; -use crate::sockets::util::{ - get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, - is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count, - set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size, - set_send_buffer_size, set_unicast_hop_limit, tcp_bind, -}; -use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily}; +use crate::runtime::AbortOnDropJoinHandle; +use crate::sockets::TcpSocket; use anyhow::Result; -use cap_net_ext::AddressFamily; -use futures::Future; use io_lifetimes::AsSocketlike; -use io_lifetimes::views::SocketlikeView; use rustix::io::Errno; -use rustix::net::sockopt; use std::io; use std::mem; -use std::net::{Shutdown, SocketAddr}; -use std::pin::Pin; +use std::net::Shutdown; use std::sync::Arc; -use std::task::{Poll, Waker}; use tokio::sync::Mutex; -/// The state of a TCP socket. 
-/// -/// This represents the various states a socket can be in during the -/// activities of binding, listening, accepting, and connecting. -enum TcpState { - /// The initial state for a newly-created socket. - Default(tokio::net::TcpSocket), - - /// Binding started via `start_bind`. - BindStarted(tokio::net::TcpSocket), - - /// Binding finished via `finish_bind`. The socket has an address but - /// is not yet listening for connections. - Bound(tokio::net::TcpSocket), - - /// Listening started via `listen_start`. - ListenStarted(tokio::net::TcpSocket), - - /// The socket is now listening and waiting for an incoming connection. - Listening { - listener: tokio::net::TcpListener, - pending_accept: Option>, - }, - - /// An outgoing connection is started via `start_connect`. - Connecting(Pin> + Send>>), - - /// An outgoing connection is ready to be established. - ConnectReady(io::Result), - - /// An outgoing connection has been established. - Connected { - stream: Arc, - - // WASI is single threaded, so in practice these Mutexes should never be contended: - reader: Arc>, - writer: Arc>, - }, - - Closed, -} - -impl std::fmt::Debug for TcpState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Default(_) => f.debug_tuple("Default").finish(), - Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(), - Self::Bound(_) => f.debug_tuple("Bound").finish(), - Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(), - Self::Listening { pending_accept, .. } => f - .debug_struct("Listening") - .field("pending_accept", pending_accept) - .finish(), - Self::Connecting(_) => f.debug_tuple("Connecting").finish(), - Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(), - Self::Connected { .. } => f.debug_tuple("Connected").finish(), - Self::Closed => write!(f, "Closed"), - } - } -} - -/// A host TCP socket, plus associated bookkeeping. 
-pub struct TcpSocket { - /// The current state in the bind/listen/accept/connect progression. - tcp_state: TcpState, - - /// The desired listen queue size. - listen_backlog_size: u32, - - family: SocketAddressFamily, - - // The socket options below are not automatically inherited from the listener - // on all platforms. So we keep track of which options have been explicitly - // set and manually apply those values to newly accepted clients. - #[cfg(target_os = "macos")] - receive_buffer_size: Option, - #[cfg(target_os = "macos")] - send_buffer_size: Option, - #[cfg(target_os = "macos")] - hop_limit: Option, - #[cfg(target_os = "macos")] - keep_alive_idle_time: Option, -} - impl TcpSocket { - /// Create a new socket in the given family. - pub(crate) fn new(family: AddressFamily) -> io::Result { - with_ambient_tokio_runtime(|| { - let (socket, family) = match family { - AddressFamily::Ipv4 => { - let socket = tokio::net::TcpSocket::new_v4()?; - (socket, SocketAddressFamily::Ipv4) - } - AddressFamily::Ipv6 => { - let socket = tokio::net::TcpSocket::new_v6()?; - sockopt::set_ipv6_v6only(&socket, true)?; - (socket, SocketAddressFamily::Ipv6) - } - }; - - Self::from_state(TcpState::Default(socket), family) - }) - } - - /// Create a `TcpSocket` from an existing socket. - fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result { - Ok(Self { - tcp_state: state, - listen_backlog_size: DEFAULT_TCP_BACKLOG, - family, - #[cfg(target_os = "macos")] - receive_buffer_size: None, - #[cfg(target_os = "macos")] - send_buffer_size: None, - #[cfg(target_os = "macos")] - hop_limit: None, - #[cfg(target_os = "macos")] - keep_alive_idle_time: None, - }) - } - - fn as_std_view(&self) -> SocketResult> { - use crate::p2::bindings::sockets::network::ErrorCode; - - match &self.tcp_state { - TcpState::Default(socket) | TcpState::Bound(socket) => { - Ok(socket.as_socketlike_view::()) - } - TcpState::Connected { stream, .. 
} => { - Ok(stream.as_socketlike_view::()) - } - TcpState::Listening { listener, .. } => { - Ok(listener.as_socketlike_view::()) - } - - TcpState::BindStarted(..) - | TcpState::ListenStarted(..) - | TcpState::Connecting(..) - | TcpState::ConnectReady(..) - | TcpState::Closed => Err(ErrorCode::InvalidState.into()), - } - } -} - -impl TcpSocket { - pub(crate) fn start_bind(&mut self, local_address: SocketAddr) -> Result<(), ErrorCode> { - let tokio_socket = match &self.tcp_state { - TcpState::Default(socket) => socket, - TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()), - _ => return Err(ErrorCode::InvalidState), - }; - - if !is_valid_unicast_address(local_address.ip()) - || !is_valid_address_family(local_address.ip(), self.family) - { - return Err(ErrorCode::InvalidArgument); - }; - - { - tcp_bind(&tokio_socket, local_address)?; - - self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { - TcpState::Default(socket) => TcpState::BindStarted(socket), - _ => unreachable!(), - }; - - Ok(()) - } - } - - pub(crate) fn finish_bind(&mut self) -> SocketResult<()> { - match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { - TcpState::BindStarted(socket) => { - self.tcp_state = TcpState::Bound(socket); - Ok(()) - } - current_state => { - // Reset the state so that the outside world doesn't see this socket as closed - self.tcp_state = current_state; - Err(ErrorCode::NotInProgress.into()) - } - } - } - - pub(crate) fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> { - match self.tcp_state { - TcpState::Default(..) | TcpState::Bound(..) => {} - - TcpState::Connecting(..) | TcpState::ConnectReady(..) 
=> { - return Err(ErrorCode::ConcurrencyConflict.into()); - } - - _ => return Err(ErrorCode::InvalidState.into()), - }; - - if !is_valid_unicast_address(remote_address.ip()) - || !is_valid_remote_address(remote_address) - || !is_valid_address_family(remote_address.ip(), self.family) - { - return Err(ErrorCode::InvalidArgument.into()); - }; - - let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) = - std::mem::replace(&mut self.tcp_state, TcpState::Closed) - else { - unreachable!(); - }; - - let future = tokio_socket.connect(remote_address); - - self.tcp_state = TcpState::Connecting(Box::pin(future)); - Ok(()) - } - - pub(crate) fn finish_connect(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> { - let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed); - let result = match previous_state { - TcpState::ConnectReady(result) => result, - TcpState::Connecting(mut future) => { - let mut cx = std::task::Context::from_waker(Waker::noop()); - match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) { - Poll::Ready(result) => result, - Poll::Pending => { - self.tcp_state = TcpState::Connecting(future); - return Err(ErrorCode::WouldBlock.into()); - } - } - } - previous_state => { - self.tcp_state = previous_state; - return Err(ErrorCode::NotInProgress.into()); - } - }; - - match result { - Ok(stream) => { - let stream = Arc::new(stream); - let reader = Arc::new(Mutex::new(TcpReader::new(stream.clone()))); - let writer = Arc::new(Mutex::new(TcpWriter::new(stream.clone()))); - self.tcp_state = TcpState::Connected { - stream, - reader: reader.clone(), - writer: writer.clone(), - }; - let input: DynInputStream = Box::new(TcpReadStream(reader)); - let output: DynOutputStream = Box::new(TcpWriteStream(writer)); - Ok((input, output)) - } - Err(err) => { - self.tcp_state = TcpState::Closed; - Err(err.into()) - } - } - } - - pub(crate) fn start_listen(&mut self) -> SocketResult<()> { - match std::mem::replace(&mut 
self.tcp_state, TcpState::Closed) { - TcpState::Bound(tokio_socket) => { - self.tcp_state = TcpState::ListenStarted(tokio_socket); - Ok(()) - } - TcpState::ListenStarted(tokio_socket) => { - self.tcp_state = TcpState::ListenStarted(tokio_socket); - Err(ErrorCode::ConcurrencyConflict.into()) - } - previous_state => { - self.tcp_state = previous_state; - Err(ErrorCode::InvalidState.into()) - } - } - } - - pub(crate) fn finish_listen(&mut self) -> SocketResult<()> { - let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { - TcpState::ListenStarted(tokio_socket) => tokio_socket, - previous_state => { - self.tcp_state = previous_state; - return Err(ErrorCode::NotInProgress.into()); - } - }; - - match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) { - Ok(listener) => { - self.tcp_state = TcpState::Listening { - listener, - pending_accept: None, - }; - Ok(()) - } - Err(err) => { - self.tcp_state = TcpState::Closed; - - Err(match Errno::from_io_error(&err) { - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE - // According to the docs, `listen` can return EMFILE on Windows. - // This is odd, because we're not trying to create a new socket - // or file descriptor of any kind. So we rewrite it to less - // surprising error code. - // - // At the time of writing, this behavior has never been experimentally - // observed by any of the wasmtime authors, so we're relying fully - // on Microsoft's documentation here. 
- #[cfg(windows)] - Some(Errno::MFILE) => Errno::NOBUFS.into(), - - _ => err.into(), - }) - } - } - } - - pub(crate) fn accept(&mut self) -> SocketResult<(Self, DynInputStream, DynOutputStream)> { - let TcpState::Listening { - listener, - pending_accept, - } = &mut self.tcp_state - else { - return Err(ErrorCode::InvalidState.into()); - }; - - let result = match pending_accept.take() { - Some(result) => result, - None => { - let mut cx = std::task::Context::from_waker(Waker::noop()); - match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx)) - .map_ok(|(stream, _)| stream) - { - Poll::Ready(result) => result, - Poll::Pending => Err(Errno::WOULDBLOCK.into()), - } - } - }; - - let client = result.map_err(|err| match Errno::from_io_error(&err) { - // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS - // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, - // > or the service provider is still processing a callback function. - // - // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, - // because in POSIX this error is only returned by a non-blocking - // `connect` and wasi-sockets has a different solution for that. - #[cfg(windows)] - Some(Errno::INPROGRESS) => Errno::INTR.into(), - - // Normalize Linux' non-standard behavior. - // - // From https://man7.org/linux/man-pages/man2/accept.2.html: - // > Linux accept() passes already-pending network errors on the - // > new socket as an error code from accept(). This behavior - // > differs from other BSD socket implementations. (...) - #[cfg(target_os = "linux")] - Some( - Errno::CONNRESET - | Errno::NETRESET - | Errno::HOSTUNREACH - | Errno::HOSTDOWN - | Errno::NETDOWN - | Errno::NETUNREACH - | Errno::PROTO - | Errno::NOPROTOOPT - | Errno::NONET - | Errno::OPNOTSUPP, - ) => Errno::CONNABORTED.into(), - - _ => err, - })?; - - #[cfg(target_os = "macos")] - { - // Manually inherit socket options from listener. 
We only have to - // do this on platforms that don't already do this automatically - // and only if a specific value was explicitly set on the listener. - - if let Some(size) = self.receive_buffer_size { - _ = set_receive_buffer_size(&client, size); // Ignore potential error. - } - - if let Some(size) = self.send_buffer_size { - _ = set_send_buffer_size(&client, size); // Ignore potential error. - } - - // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. - if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) { - _ = rustix::net::sockopt::set_ipv6_unicast_hops(&client, Some(ttl)); // Ignore potential error. - } - - if let Some(value) = self.keep_alive_idle_time { - _ = set_keep_alive_idle_time(&client, value); // Ignore potential error. - } - } - - let client = Arc::new(client); - + pub(crate) fn p2_streams(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> { + let client = self.tcp_stream_arc()?; let reader = Arc::new(Mutex::new(TcpReader::new(client.clone()))); let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone()))); - - let input: DynInputStream = Box::new(TcpReadStream(reader.clone())); - let output: DynOutputStream = Box::new(TcpWriteStream(writer.clone())); - let tcp_socket = TcpSocket::from_state( - TcpState::Connected { - stream: client, - reader, - writer, - }, - self.family, - )?; - - Ok((tcp_socket, input, output)) - } - - pub(crate) fn local_address(&self) -> SocketResult { - let view = match self.tcp_state { - TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()), - TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()), - _ => self.as_std_view()?, - }; - - Ok(view.local_addr()?) - } - - pub(crate) fn remote_address(&self) -> SocketResult { - let view = match self.tcp_state { - TcpState::Connected { .. } => self.as_std_view()?, - TcpState::Connecting(..) | TcpState::ConnectReady(..) 
=> { - return Err(ErrorCode::ConcurrencyConflict.into()); - } - _ => return Err(ErrorCode::InvalidState.into()), - }; - - Ok(view.peer_addr()?) - } - - pub(crate) fn is_listening(&self) -> bool { - matches!(self.tcp_state, TcpState::Listening { .. }) - } - - pub(crate) fn address_family(&self) -> SocketAddressFamily { - self.family - } - - pub(crate) fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> { - const MIN_BACKLOG: u32 = 1; - const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. - - if value == 0 { - return Err(ErrorCode::InvalidArgument.into()); - } - - // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. - let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG); - - match &self.tcp_state { - TcpState::Default(..) | TcpState::Bound(..) => { - // Socket not listening yet. Stash value for first invocation to `listen`. - } - TcpState::Listening { listener, .. } => { - // Try to update the backlog by calling `listen` again. - // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. - - rustix::net::listen(&listener, value.try_into().unwrap()) - .map_err(|_| ErrorCode::NotSupported)?; - } - _ => return Err(ErrorCode::InvalidState.into()), - } - self.listen_backlog_size = value; - - Ok(()) - } - - pub(crate) fn keep_alive_enabled(&self) -> SocketResult { - let view = &*self.as_std_view()?; - Ok(sockopt::socket_keepalive(view)?) - } - - pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> { - let view = &*self.as_std_view()?; - Ok(sockopt::set_socket_keepalive(view, value)?) - } - - pub(crate) fn keep_alive_idle_time(&self) -> SocketResult { - let view = &*self.as_std_view()?; - Ok(sockopt::tcp_keepidle(view)?) 
- } - - pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> SocketResult<()> { - { - let view = &*self.as_std_view()?; - set_keep_alive_idle_time(view, value)?; - } - - #[cfg(target_os = "macos")] - { - self.keep_alive_idle_time = Some(value); - } - - Ok(()) - } - - pub(crate) fn keep_alive_interval(&self) -> SocketResult { - let view = &*self.as_std_view()?; - Ok(sockopt::tcp_keepintvl(view)?) - } - - pub(crate) fn set_keep_alive_interval( - &self, - duration: std::time::Duration, - ) -> SocketResult<()> { - let view = &*self.as_std_view()?; - Ok(set_keep_alive_interval(view, duration)?) - } - - pub(crate) fn keep_alive_count(&self) -> SocketResult { - let view = &*self.as_std_view()?; - Ok(sockopt::tcp_keepcnt(view)?) - } - - pub(crate) fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> { - let view = &*self.as_std_view()?; - Ok(set_keep_alive_count(view, value)?) - } - - pub(crate) fn hop_limit(&self) -> SocketResult { - let view = &*self.as_std_view()?; - - let ttl = get_unicast_hop_limit(view, self.family)?; - Ok(ttl) - } - - pub(crate) fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> { - { - let view = &*self.as_std_view()?; - - set_unicast_hop_limit(view, self.family, value)?; - } - - #[cfg(target_os = "macos")] - { - self.hop_limit = Some(value); - } - - Ok(()) - } - - pub(crate) fn receive_buffer_size(&self) -> SocketResult { - let view = &*self.as_std_view()?; - - Ok(receive_buffer_size(view)?) - } - - pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> SocketResult<()> { - { - let view = &*self.as_std_view()?; - - set_receive_buffer_size(view, value)?; - } - - #[cfg(target_os = "macos")] - { - self.receive_buffer_size = Some(value); - } - - Ok(()) - } - - pub(crate) fn send_buffer_size(&self) -> SocketResult { - let view = &*self.as_std_view()?; - - Ok(send_buffer_size(view)?) 
+ self.set_p2_streaming_state(P2TcpStreamingState { + stream: client.clone(), + reader: reader.clone(), + writer: writer.clone(), + })?; + let input: DynInputStream = Box::new(TcpReadStream(reader)); + let output: DynOutputStream = Box::new(TcpWriteStream(writer)); + Ok((input, output)) } +} - pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> SocketResult<()> { - { - let view = &*self.as_std_view()?; - - set_send_buffer_size(view, value)?; - } - - #[cfg(target_os = "macos")] - { - self.send_buffer_size = Some(value); - } - - Ok(()) - } +pub(crate) struct P2TcpStreamingState { + pub(crate) stream: Arc, + reader: Arc>, + writer: Arc>, +} +impl P2TcpStreamingState { pub(crate) fn shutdown(&self, how: Shutdown) -> SocketResult<()> { - let TcpState::Connected { reader, writer, .. } = &self.tcp_state else { - return Err(ErrorCode::InvalidState.into()); - }; - if let Shutdown::Both | Shutdown::Read = how { - try_lock_for_socket(reader)?.shutdown(); + try_lock_for_socket(&self.reader)?.shutdown(); } if let Shutdown::Both | Shutdown::Write = how { - try_lock_for_socket(writer)?.shutdown(); + try_lock_for_socket(&self.writer)?.shutdown(); } Ok(()) } } -#[async_trait::async_trait] -impl Pollable for TcpSocket { - async fn ready(&mut self) { - match &mut self.tcp_state { - TcpState::Default(..) - | TcpState::BindStarted(..) - | TcpState::Bound(..) - | TcpState::ListenStarted(..) - | TcpState::ConnectReady(..) - | TcpState::Closed - | TcpState::Connected { .. } => { - // No async operation in progress. 
- } - TcpState::Connecting(future) => { - self.tcp_state = TcpState::ConnectReady(future.as_mut().await); - } - TcpState::Listening { - listener, - pending_accept, - } => match pending_accept { - Some(_) => {} - None => { - let result = futures::future::poll_fn(|cx| { - listener.poll_accept(cx).map_ok(|(stream, _)| stream) - }) - .await; - *pending_accept = Some(result); - } - }, - } - } -} - struct TcpReader { stream: Arc, closed: bool, @@ -964,7 +352,7 @@ fn try_lock_for_stream(mutex: &Mutex) -> Result(mutex: &Mutex) -> Result, SocketError> { +fn try_lock_for_socket(mutex: &Mutex) -> SocketResult> { mutex.try_lock().map_err(|_| { SocketError::trap(anyhow::anyhow!( "concurrent access to resource not supported" diff --git a/crates/wasi/src/p3/bindings.rs b/crates/wasi/src/p3/bindings.rs index 89c9bd5888b8..c9c905c75b8e 100644 --- a/crates/wasi/src/p3/bindings.rs +++ b/crates/wasi/src/p3/bindings.rs @@ -94,7 +94,7 @@ mod generated { with: { "wasi:cli/terminal-input/terminal-input": crate::p3::cli::TerminalInput, "wasi:cli/terminal-output/terminal-output": crate::p3::cli::TerminalOutput, - "wasi:sockets/types/tcp-socket": crate::p3::sockets::tcp::TcpSocket, + "wasi:sockets/types/tcp-socket": crate::sockets::TcpSocket, "wasi:sockets/types/udp-socket": crate::sockets::UdpSocket, }, trappable_error_type: { diff --git a/crates/wasi/src/p3/sockets/conv.rs b/crates/wasi/src/p3/sockets/conv.rs index 03e1aff33a15..915bd3e96271 100644 --- a/crates/wasi/src/p3/sockets/conv.rs +++ b/crates/wasi/src/p3/sockets/conv.rs @@ -131,6 +131,15 @@ impl From for types::IpAddressFamily { } } +impl From for SocketAddressFamily { + fn from(family: types::IpAddressFamily) -> Self { + match family { + types::IpAddressFamily::Ipv4 => Self::Ipv4, + types::IpAddressFamily::Ipv6 => Self::Ipv6, + } + } +} + impl From for types::ErrorCode { fn from(value: std::io::Error) -> Self { (&value).into() @@ -231,6 +240,7 @@ impl From for types::ErrorCode { crate::sockets::util::ErrorCode::ConnectionAborted 
=> Self::ConnectionAborted, crate::sockets::util::ErrorCode::DatagramTooLarge => Self::DatagramTooLarge, crate::sockets::util::ErrorCode::NotInProgress => Self::InvalidState, + crate::sockets::util::ErrorCode::ConcurrencyConflict => Self::InvalidState, } } } diff --git a/crates/wasi/src/p3/sockets/host/types/tcp.rs b/crates/wasi/src/p3/sockets/host/types/tcp.rs index 221f5bcba0b1..b23f9ed62fab 100644 --- a/crates/wasi/src/p3/sockets/host/types/tcp.rs +++ b/crates/wasi/src/p3/sockets/host/types/tcp.rs @@ -5,19 +5,13 @@ use crate::p3::bindings::sockets::types::{ Duration, ErrorCode, HostTcpSocket, HostTcpSocketWithStore, IpAddressFamily, IpSocketAddress, TcpSocket, }; -use crate::p3::sockets::tcp::{NonInheritedOptions, TcpState}; use crate::p3::sockets::{SocketResult, WasiSockets}; -use crate::sockets::util::{ - is_valid_address_family, is_valid_remote_address, is_valid_unicast_address, -}; -use crate::sockets::{SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView}; -use anyhow::{Context as _, anyhow}; +use crate::sockets::{NonInheritedOptions, SocketAddrUse, SocketAddressFamily, WasiSocketsCtxView}; +use anyhow::Context; use bytes::BytesMut; use io_lifetimes::AsSocketlike as _; -use rustix::io::Errno; use std::future::poll_fn; use std::io::Cursor; -use std::mem; use std::net::{Shutdown, SocketAddr}; use std::pin::pin; use std::sync::Arc; @@ -28,10 +22,6 @@ use wasmtime::component::{ Resource, ResourceTable, StreamReader, StreamWriter, }; -fn is_tcp_allowed(store: &Accessor) -> bool { - store.with(|mut view| view.get().ctx.allowed_network_uses.tcp) -} - fn get_socket<'a>( table: &'a ResourceTable, socket: &'a Resource, @@ -74,50 +64,12 @@ impl AccessorTask> for ListenTask { }) else { return Ok(()); }; - let state = match res { - Ok((stream, _addr)) => { - self.options.apply(self.family, &stream); - TcpState::Connected(Arc::new(stream)) - } - Err(err) => { - match Errno::from_io_error(&err) { - // From: 
https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS - // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, - // > or the service provider is still processing a callback function. - // - // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, - // because in POSIX this error is only returned by a non-blocking - // `connect` and wasi-sockets has a different solution for that. - #[cfg(windows)] - Some(Errno::INPROGRESS) => TcpState::Error(ErrorCode::Unknown), - - // Normalize Linux' non-standard behavior. - // - // From https://man7.org/linux/man-pages/man2/accept.2.html: - // > Linux accept() passes already-pending network errors on the - // > new socket as an error code from accept(). This behavior - // > differs from other BSD socket implementations. (...) - #[cfg(target_os = "linux")] - Some( - Errno::CONNRESET - | Errno::NETRESET - | Errno::HOSTUNREACH - | Errno::HOSTDOWN - | Errno::NETDOWN - | Errno::NETUNREACH - | Errno::PROTO - | Errno::NOPROTOOPT - | Errno::NONET - | Errno::OPNOTSUPP, - ) => TcpState::Error(ErrorCode::ConnectionAborted), - _ => TcpState::Error(err.into()), - } - } - }; + let socket = TcpSocket::new_accept(res.map(|p| p.0), &self.options, self.family) + .unwrap_or_else(|err| TcpSocket::new_error(err, self.family)); let socket = store.with(|mut view| { view.get() .table - .push(TcpSocket::from_state(state, self.family)) + .push(socket) .context("failed to push socket resource to table") })?; if let Some(socket) = tx.write(Some(socket)).await { @@ -215,14 +167,13 @@ impl HostTcpSocketWithStore for WasiSockets { local_address: IpSocketAddress, ) -> SocketResult<()> { let local_address = SocketAddr::from(local_address); - if !is_tcp_allowed(store) - || !is_addr_allowed(store, local_address, SocketAddrUse::TcpBind).await - { + if !is_addr_allowed(store, local_address, SocketAddrUse::TcpBind).await { return Err(ErrorCode::AccessDenied.into()); } store.with(|mut view| 
{ let socket = get_socket_mut(view.get().table, &socket)?; - socket.bind(local_address)?; + socket.start_bind(local_address)?; + socket.finish_bind()?; Ok(()) }) } @@ -233,47 +184,22 @@ impl HostTcpSocketWithStore for WasiSockets { remote_address: IpSocketAddress, ) -> SocketResult<()> { let remote_address = SocketAddr::from(remote_address); - if !is_tcp_allowed(store) - || !is_addr_allowed(store, remote_address, SocketAddrUse::TcpConnect).await - { + if !is_addr_allowed(store, remote_address, SocketAddrUse::TcpConnect).await { return Err(ErrorCode::AccessDenied.into()); } + let addr = remote_address.into(); let sock = store.with(|mut view| -> SocketResult<_> { - let ip = remote_address.ip(); let socket = get_socket_mut(view.get().table, &socket)?; - if !is_valid_unicast_address(ip) - || !is_valid_remote_address(remote_address) - || !is_valid_address_family(ip, socket.family) - { - return Err(ErrorCode::InvalidArgument.into()); - } - match mem::replace(&mut socket.tcp_state, TcpState::Connecting) { - TcpState::Default(sock) | TcpState::Bound(sock) => Ok(sock), - tcp_state => { - socket.tcp_state = tcp_state; - Err(ErrorCode::InvalidState.into()) - } - } + Ok(socket.start_connect(&addr)?) 
})?; // FIXME: handle possible cancellation of the outer `connect` // https://github.com/bytecodealliance/wasmtime/pull/11291#discussion_r2223917986 - let res = sock.connect(remote_address).await; + let res = sock.connect(addr).await; store.with(|mut view| -> SocketResult<_> { let socket = get_socket_mut(view.get().table, &socket)?; - if !matches!(socket.tcp_state, TcpState::Connecting) { - return Err(TrappableError::trap(anyhow!("corrupted socket state"))); - } - match res { - Ok(stream) => { - socket.tcp_state = TcpState::Connected(Arc::new(stream)); - Ok(()) - } - Err(err) => { - socket.tcp_state = TcpState::Closed; - Err(ErrorCode::from(err).into()) - } - } + socket.finish_connect(res)?; + Ok(()) }) } @@ -282,46 +208,12 @@ impl HostTcpSocketWithStore for WasiSockets { socket: Resource, ) -> SocketResult>> { store.with(|mut view| { - if !view.get().ctx.allowed_network_uses.tcp { - return Err(ErrorCode::AccessDenied.into()); - } - let TcpSocket { - tcp_state, - listen_backlog_size, - family, - options, - } = get_socket_mut(view.get().table, &socket)?; - let sock = match mem::replace(tcp_state, TcpState::Closed) { - TcpState::Default(sock) | TcpState::Bound(sock) => sock, - prev => { - *tcp_state = prev; - return Err(ErrorCode::InvalidState.into()); - } - }; - let listener = match sock.listen(*listen_backlog_size) { - Ok(listener) => listener, - Err(err) => { - match Errno::from_io_error(&err) { - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE - // According to the docs, `listen` can return EMFILE on Windows. - // This is odd, because we're not trying to create a new socket - // or file descriptor of any kind. So we rewrite it to less - // surprising error code. - // - // At the time of writing, this behavior has never been experimentally - // observed by any of the wasmtime authors, so we're relying fully - // on Microsoft's documentation here. 
- #[cfg(windows)] - Some(Errno::MFILE) => return Err(ErrorCode::OutOfMemory.into()), - - _ => return Err(ErrorCode::from(err).into()), - } - } - }; - let listener = Arc::new(listener); - *tcp_state = TcpState::Listening(Arc::clone(&listener)); - let family = *family; - let options = options.clone(); + let socket = get_socket_mut(view.get().table, &socket)?; + socket.start_listen()?; + socket.finish_listen()?; + let listener = socket.tcp_listener_arc().unwrap().clone(); + let family = socket.address_family(); + let options = socket.non_inherited_options().clone(); let (tx, rx) = view .instance() .stream(&mut view) @@ -341,17 +233,14 @@ impl HostTcpSocketWithStore for WasiSockets { async fn send( store: &Accessor, socket: Resource, - data: StreamReader, + mut data: StreamReader, ) -> SocketResult<()> { - let (stream, mut data) = store.with(|mut view| -> SocketResult<_> { + let stream = store.with(|mut view| -> SocketResult<_> { let sock = get_socket(view.get().table, &socket)?; - if let TcpState::Connected(stream) | TcpState::Receiving(stream) = &sock.tcp_state { - Ok((Arc::clone(&stream), data)) - } else { - Err(ErrorCode::InvalidState.into()) - } + let stream = sock.tcp_stream_arc()?; + Ok(Arc::clone(stream)) })?; - let mut buf = Vec::with_capacity(8096); + let mut buf = Vec::with_capacity(DEFAULT_BUFFER_CAPACITY); let mut result = Ok(()); while !data.is_closed() { buf = data.read(store, buf).await; @@ -388,10 +277,10 @@ impl HostTcpSocketWithStore for WasiSockets { let (mut data_tx, data_rx) = instance .stream(&mut view) .context("failed to create stream")?; - let TcpSocket { tcp_state, .. 
} = get_socket_mut(view.get().table, &socket)?; - match mem::replace(tcp_state, TcpState::Closed) { - TcpState::Connected(stream) => { - *tcp_state = TcpState::Receiving(Arc::clone(&stream)); + let socket = get_socket_mut(view.get().table, &socket)?; + match socket.start_receive() { + Some(stream) => { + let stream = stream.clone(); let (result_tx, result_rx) = instance .future(&mut view, || unreachable!()) .context("failed to create future")?; @@ -402,8 +291,7 @@ impl HostTcpSocketWithStore for WasiSockets { }); Ok((data_rx, result_rx)) } - prev => { - *tcp_state = prev; + None => { let (mut result_tx, result_rx) = instance .future(&mut view, || Err(ErrorCode::InvalidState)) .context("failed to create future")?; @@ -418,7 +306,9 @@ impl HostTcpSocketWithStore for WasiSockets { impl HostTcpSocket for WasiSocketsCtxView<'_> { fn new(&mut self, address_family: IpAddressFamily) -> wasmtime::Result> { - let socket = TcpSocket::new(address_family.into()).context("failed to create socket")?; + let family = address_family.into(); + let socket = + TcpSocket::new(self.ctx, family).unwrap_or_else(|e| TcpSocket::new_error(e, family)); self.table .push(socket) .context("failed to push socket resource to table") @@ -426,12 +316,12 @@ impl HostTcpSocket for WasiSocketsCtxView<'_> { fn local_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.local_address()?) + Ok(sock.local_address()?.into()) } fn remote_address(&mut self, socket: Resource) -> SocketResult { let sock = get_socket(self.table, &socket)?; - Ok(sock.remote_address()?) 
+ Ok(sock.remote_address()?.into()) } fn is_listening(&mut self, socket: Resource) -> wasmtime::Result { @@ -441,7 +331,7 @@ impl HostTcpSocket for WasiSocketsCtxView<'_> { fn address_family(&mut self, socket: Resource) -> wasmtime::Result { let sock = get_socket(self.table, &socket)?; - Ok(sock.address_family()) + Ok(sock.address_family().into()) } fn set_listen_backlog_size( diff --git a/crates/wasi/src/p3/sockets/mod.rs b/crates/wasi/src/p3/sockets/mod.rs index 07b4db6bd3d1..d8c0bf5cec34 100644 --- a/crates/wasi/src/p3/sockets/mod.rs +++ b/crates/wasi/src/p3/sockets/mod.rs @@ -5,7 +5,6 @@ use wasmtime::component::Linker; mod conv; mod host; -pub mod tcp; pub type SocketResult = Result; pub type SocketError = TrappableError; diff --git a/crates/wasi/src/p3/sockets/tcp.rs b/crates/wasi/src/p3/sockets/tcp.rs deleted file mode 100644 index e6b7c8e99481..000000000000 --- a/crates/wasi/src/p3/sockets/tcp.rs +++ /dev/null @@ -1,409 +0,0 @@ -use core::fmt::Debug; -use core::mem; -use core::net::SocketAddr; - -use std::sync::Arc; - -use cap_net_ext::AddressFamily; -use io_lifetimes::AsSocketlike as _; -use io_lifetimes::views::SocketlikeView; -use rustix::net::sockopt; - -use crate::p3::bindings::sockets::types::{Duration, ErrorCode, IpAddressFamily, IpSocketAddress}; -use crate::runtime::with_ambient_tokio_runtime; -use crate::sockets::util::{ - get_unicast_hop_limit, is_valid_address_family, is_valid_unicast_address, receive_buffer_size, - send_buffer_size, set_keep_alive_count, set_keep_alive_idle_time, set_keep_alive_interval, - set_receive_buffer_size, set_send_buffer_size, set_unicast_hop_limit, tcp_bind, -}; -use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily}; - -/// The state of a TCP socket. -/// -/// This represents the various states a socket can be in during the -/// activities of binding, listening, accepting, and connecting. -pub(crate) enum TcpState { - /// The initial state for a newly-created socket. 
- Default(tokio::net::TcpSocket), - - /// Binding finished. The socket has an address but is not yet listening for connections. - Bound(tokio::net::TcpSocket), - - /// The socket is now listening and waiting for an incoming connection. - Listening(Arc), - - /// An outgoing connection is started. - Connecting, - - /// A connection has been established. - Connected(Arc), - - /// A connection has been established and `receive` has been called. - Receiving(Arc), - - Error(ErrorCode), - - Closed, -} - -impl Debug for TcpState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Default(_) => f.debug_tuple("Default").finish(), - Self::Bound(_) => f.debug_tuple("Bound").finish(), - Self::Listening { .. } => f.debug_tuple("Listening").finish(), - Self::Connecting => f.debug_tuple("Connecting").finish(), - Self::Connected { .. } => f.debug_tuple("Connected").finish(), - Self::Receiving { .. } => f.debug_tuple("Receiving").finish(), - Self::Error(..) => f.debug_tuple("Error").finish(), - Self::Closed => write!(f, "Closed"), - } - } -} - -/// A host TCP socket, plus associated bookkeeping. -pub struct TcpSocket { - /// The current state in the bind/listen/accept/connect progression. - pub(crate) tcp_state: TcpState, - - /// The desired listen queue size. - pub(crate) listen_backlog_size: u32, - - pub(crate) family: SocketAddressFamily, - - pub(crate) options: NonInheritedOptions, -} - -impl TcpSocket { - /// Create a new socket in the given family. 
- pub(crate) fn new(family: AddressFamily) -> std::io::Result { - with_ambient_tokio_runtime(|| { - let (socket, family) = match family { - AddressFamily::Ipv4 => { - let socket = tokio::net::TcpSocket::new_v4()?; - (socket, SocketAddressFamily::Ipv4) - } - AddressFamily::Ipv6 => { - let socket = tokio::net::TcpSocket::new_v6()?; - sockopt::set_ipv6_v6only(&socket, true)?; - (socket, SocketAddressFamily::Ipv6) - } - }; - - Ok(Self::from_state(TcpState::Default(socket), family)) - }) - } - - /// Create a `TcpSocket` from an existing socket. - pub(crate) fn from_state(state: TcpState, family: SocketAddressFamily) -> Self { - Self { - tcp_state: state, - listen_backlog_size: DEFAULT_TCP_BACKLOG, - family, - options: Default::default(), - } - } - - pub(crate) fn as_std_view(&self) -> Result, ErrorCode> { - match &self.tcp_state { - TcpState::Default(socket) | TcpState::Bound(socket) => Ok(socket.as_socketlike_view()), - TcpState::Connected(stream) | TcpState::Receiving(stream) => { - Ok(stream.as_socketlike_view()) - } - TcpState::Listening(listener) => Ok(listener.as_socketlike_view()), - TcpState::Connecting | TcpState::Closed => Err(ErrorCode::InvalidState), - TcpState::Error(err) => Err(*err), - } - } - - pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> { - let ip = addr.ip(); - if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) { - return Err(ErrorCode::InvalidArgument); - } - match mem::replace(&mut self.tcp_state, TcpState::Closed) { - TcpState::Default(sock) => { - if let Err(err) = tcp_bind(&sock, addr) { - self.tcp_state = TcpState::Default(sock); - Err(err.into()) - } else { - self.tcp_state = TcpState::Bound(sock); - Ok(()) - } - } - tcp_state => { - self.tcp_state = tcp_state; - Err(ErrorCode::InvalidState) - } - } - } - - pub(crate) fn local_address(&self) -> Result { - match &self.tcp_state { - TcpState::Bound(socket) => { - let addr = socket.local_addr()?; - Ok(addr.into()) - } - 
TcpState::Connected(stream) | TcpState::Receiving(stream) => { - let addr = stream.local_addr()?; - Ok(addr.into()) - } - TcpState::Listening(listener) => { - let addr = listener.local_addr()?; - Ok(addr.into()) - } - TcpState::Error(err) => Err(*err), - _ => Err(ErrorCode::InvalidState), - } - } - - pub(crate) fn remote_address(&self) -> Result { - match &self.tcp_state { - TcpState::Connected(stream) | TcpState::Receiving(stream) => { - let addr = stream.peer_addr()?; - Ok(addr.into()) - } - TcpState::Error(err) => Err(*err), - _ => Err(ErrorCode::InvalidState), - } - } - - pub(crate) fn is_listening(&self) -> bool { - matches!(self.tcp_state, TcpState::Listening { .. }) - } - - pub(crate) fn address_family(&self) -> IpAddressFamily { - match self.family { - SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4, - SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6, - } - } - - pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> { - const MIN_BACKLOG: u32 = 1; - const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. - - if value == 0 { - return Err(ErrorCode::InvalidArgument); - } - // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. - let value = value - .try_into() - .unwrap_or(MAX_BACKLOG) - .clamp(MIN_BACKLOG, MAX_BACKLOG); - match &self.tcp_state { - TcpState::Default(..) | TcpState::Bound(..) => { - // Socket not listening yet. Stash value for first invocation to `listen`. - self.listen_backlog_size = value; - Ok(()) - } - TcpState::Listening(listener) => { - // Try to update the backlog by calling `listen` again. - // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. 
- if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() { - return Err(ErrorCode::NotSupported); - } - self.listen_backlog_size = value; - Ok(()) - } - TcpState::Error(err) => Err(*err), - _ => Err(ErrorCode::InvalidState), - } - } - - pub(crate) fn keep_alive_enabled(&self) -> Result { - let fd = &*self.as_std_view()?; - let v = sockopt::socket_keepalive(fd)?; - Ok(v) - } - - pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> { - let fd = &*self.as_std_view()?; - sockopt::set_socket_keepalive(fd, value)?; - Ok(()) - } - - pub(crate) fn keep_alive_idle_time(&self) -> Result { - let fd = &*self.as_std_view()?; - let v = sockopt::tcp_keepidle(fd)?; - Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) - } - - pub(crate) fn set_keep_alive_idle_time(&mut self, value: Duration) -> Result<(), ErrorCode> { - let value = { - let fd = self.as_std_view()?; - set_keep_alive_idle_time(&*fd, value)? - }; - self.options.set_keep_alive_idle_time(value); - Ok(()) - } - - pub(crate) fn keep_alive_interval(&self) -> Result { - let fd = &*self.as_std_view()?; - let v = sockopt::tcp_keepintvl(fd)?; - Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) - } - - pub(crate) fn set_keep_alive_interval(&self, value: Duration) -> Result<(), ErrorCode> { - let fd = &*self.as_std_view()?; - set_keep_alive_interval(fd, core::time::Duration::from_nanos(value))?; - Ok(()) - } - - pub(crate) fn keep_alive_count(&self) -> Result { - let fd = &*self.as_std_view()?; - let v = sockopt::tcp_keepcnt(fd)?; - Ok(v) - } - - pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> { - let fd = &*self.as_std_view()?; - set_keep_alive_count(fd, value)?; - Ok(()) - } - - pub(crate) fn hop_limit(&self) -> Result { - let fd = &*self.as_std_view()?; - let n = get_unicast_hop_limit(fd, self.family)?; - Ok(n) - } - - pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> { - { - let fd = &*self.as_std_view()?; - 
set_unicast_hop_limit(fd, self.family, value)?; - } - self.options.set_hop_limit(value); - Ok(()) - } - - pub(crate) fn receive_buffer_size(&self) -> Result { - let fd = &*self.as_std_view()?; - let n = receive_buffer_size(fd)?; - Ok(n) - } - - pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { - let res = { - let fd = &*self.as_std_view()?; - set_receive_buffer_size(fd, value)? - }; - self.options.set_receive_buffer_size(res); - Ok(()) - } - - pub(crate) fn send_buffer_size(&self) -> Result { - let fd = &*self.as_std_view()?; - let n = send_buffer_size(fd)?; - Ok(n) - } - - pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { - let res = { - let fd = &*self.as_std_view()?; - set_send_buffer_size(fd, value)? - }; - self.options.set_send_buffer_size(res); - Ok(()) - } -} - -#[cfg(not(target_os = "macos"))] -pub use inherits_option::*; -#[cfg(not(target_os = "macos"))] -mod inherits_option { - use crate::sockets::SocketAddressFamily; - use tokio::net::TcpStream; - - #[derive(Default, Clone)] - pub struct NonInheritedOptions; - - impl NonInheritedOptions { - pub fn set_keep_alive_idle_time(&mut self, _value: u64) {} - - pub fn set_hop_limit(&mut self, _value: u8) {} - - pub fn set_receive_buffer_size(&mut self, _value: usize) {} - - pub fn set_send_buffer_size(&mut self, _value: usize) {} - - pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {} - } -} - -#[cfg(target_os = "macos")] -pub use does_not_inherit_options::*; -#[cfg(target_os = "macos")] -mod does_not_inherit_options { - use crate::sockets::SocketAddressFamily; - use rustix::net::sockopt; - use std::sync::Arc; - use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed}; - use std::time::Duration; - use tokio::net::TcpStream; - - // The socket options below are not automatically inherited from the listener - // on all platforms. 
So we keep track of which options have been explicitly - // set and manually apply those values to newly accepted clients. - #[derive(Default, Clone)] - pub struct NonInheritedOptions(Arc); - - #[derive(Default)] - struct Inner { - receive_buffer_size: AtomicUsize, - send_buffer_size: AtomicUsize, - hop_limit: AtomicU8, - keep_alive_idle_time: AtomicU64, // nanoseconds - } - - impl NonInheritedOptions { - pub fn set_keep_alive_idle_time(&mut self, value: u64) { - self.0.keep_alive_idle_time.store(value, Relaxed); - } - - pub fn set_hop_limit(&mut self, value: u8) { - self.0.hop_limit.store(value, Relaxed); - } - - pub fn set_receive_buffer_size(&mut self, value: usize) { - self.0.receive_buffer_size.store(value, Relaxed); - } - - pub fn set_send_buffer_size(&mut self, value: usize) { - self.0.send_buffer_size.store(value, Relaxed); - } - - pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) { - // Manually inherit socket options from listener. We only have to - // do this on platforms that don't already do this automatically - // and only if a specific value was explicitly set on the listener. - - let receive_buffer_size = self.0.receive_buffer_size.load(Relaxed); - if receive_buffer_size > 0 { - // Ignore potential error. - _ = sockopt::set_socket_recv_buffer_size(&stream, receive_buffer_size); - } - - let send_buffer_size = self.0.send_buffer_size.load(Relaxed); - if send_buffer_size > 0 { - // Ignore potential error. - _ = sockopt::set_socket_send_buffer_size(&stream, send_buffer_size); - } - - // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. - if family == SocketAddressFamily::Ipv6 { - let hop_limit = self.0.hop_limit.load(Relaxed); - if hop_limit > 0 { - // Ignore potential error. - _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hop_limit)); - } - } - - let keep_alive_idle_time = self.0.keep_alive_idle_time.load(Relaxed); - if keep_alive_idle_time > 0 { - // Ignore potential error. 
- _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(keep_alive_idle_time)); - } - } - } -} diff --git a/crates/wasi/src/sockets/mod.rs b/crates/wasi/src/sockets/mod.rs index ce78e78f08c6..3783895814e5 100644 --- a/crates/wasi/src/sockets/mod.rs +++ b/crates/wasi/src/sockets/mod.rs @@ -5,9 +5,13 @@ use std::pin::Pin; use std::sync::Arc; use wasmtime::component::{HasData, ResourceTable}; +mod tcp; mod udp; pub(crate) mod util; +#[cfg(feature = "p3")] +pub(crate) use tcp::NonInheritedOptions; +pub use tcp::TcpSocket; pub use udp::UdpSocket; pub(crate) struct WasiSockets; diff --git a/crates/wasi/src/sockets/tcp.rs b/crates/wasi/src/sockets/tcp.rs new file mode 100644 index 000000000000..09dd2ba8dea6 --- /dev/null +++ b/crates/wasi/src/sockets/tcp.rs @@ -0,0 +1,816 @@ +use crate::p2::P2TcpStreamingState; +use crate::runtime::with_ambient_tokio_runtime; +use crate::sockets::util::{ + ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address, + is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count, + set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size, + set_send_buffer_size, set_unicast_hop_limit, tcp_bind, +}; +use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx}; +use io_lifetimes::AsSocketlike as _; +use io_lifetimes::views::SocketlikeView; +use rustix::io::Errno; +use rustix::net::sockopt; +use std::fmt::Debug; +use std::io; +use std::mem; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use std::time::Duration; + +/// The state of a TCP socket. +/// +/// This represents the various states a socket can be in during the +/// activities of binding, listening, accepting, and connecting. Note that this +/// state machine encompasses both WASIp2 and WASIp3. +enum TcpState { + /// The initial state for a newly-created socket. 
+ /// + /// From here a socket can transition to `BindStarted`, `ListenStarted`, or + /// `Connecting`. + Default(tokio::net::TcpSocket), + + /// A state indicating that a bind has been started and must be finished + /// subsequently with `finish_bind`. + /// + /// From here a socket can transition to `Bound`. + BindStarted(tokio::net::TcpSocket), + + /// Binding finished. The socket has an address but is not yet listening for + /// connections. + /// + /// From here a socket can transition to `ListenStarted`, or `Connecting`. + Bound(tokio::net::TcpSocket), + + /// Listening on a socket has started and must be completed with + /// `finish_listen`. + /// + /// From here a socket can transition to `Listening`. + ListenStarted(tokio::net::TcpSocket), + + /// The socket is now listening and waiting for an incoming connection. + /// + /// Sockets will not leave this state. + Listening { + /// The raw tokio-based TCP listener managing the underlying socket. + listener: Arc, + + /// The last-accepted connection, set during the `ready` method and read + /// during the `accept` method. Note that this is only used for WASIp2 + /// at this time. + pending_accept: Option>, + }, + + /// An outgoing connection is started. + /// + /// This is created via the `start_connect` method. The payload here is an + /// optionally-specified owned future for the result of the connect. In + /// WASIp2 the future lives here, but in WASIp3 it lives on the event loop + /// so this is `None`. + /// + /// From here a socket can transition to `ConnectReady` or `Connected`. + Connecting(Option> + Send>>>), + + /// A connection via `Connecting` has completed. + /// + /// This is present for WASIp2 where the `Connecting` state stores `Some` of + /// a future, and the result of that future is recorded here when it + /// finishes as part of the `ready` method. + /// + /// From here a socket can transition to `Connected`. + ConnectReady(io::Result), + + /// A connection has been established.
+ /// + /// This is created either via `finish_connect` or for freshly accepted + /// sockets from a TCP listener. + /// + /// From here a socket can transition to `Receiving` or `P2Streaming`. + Connected(Arc), + + /// A connection has been established and `receive` has been called. + /// + /// A socket will not transition out of this state. + #[cfg(feature = "p3")] + Receiving(Arc), + + /// This is a WASIp2-bound socket which stores some extra state for + /// read/write streams to handle TCP shutdown. + /// + /// A socket will not transition out of this state. + P2Streaming(Box), + + /// This is not actually a socket but a deferred error. + /// + /// This error came out of `accept` and is deferred until the socket is + /// operated on. + #[cfg(feature = "p3")] + Error(io::Error), + + /// The socket is closed and no more operations can be performed. + Closed, +} + +impl Debug for TcpState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Default(_) => f.debug_tuple("Default").finish(), + Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(), + Self::Bound(_) => f.debug_tuple("Bound").finish(), + Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(), + Self::Listening { .. } => f.debug_tuple("Listening").finish(), + Self::Connecting(..) => f.debug_tuple("Connecting").finish(), + Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(), + Self::Connected { .. } => f.debug_tuple("Connected").finish(), + #[cfg(feature = "p3")] + Self::Receiving { .. } => f.debug_tuple("Receiving").finish(), + Self::P2Streaming(_) => f.debug_tuple("P2Streaming").finish(), + #[cfg(feature = "p3")] + Self::Error(..) => f.debug_tuple("Error").finish(), + Self::Closed => write!(f, "Closed"), + } + } +} + +/// A host TCP socket, plus associated bookkeeping. +pub struct TcpSocket { + /// The current state in the bind/listen/accept/connect progression. + tcp_state: TcpState, + + /// The desired listen queue size. 
+ listen_backlog_size: u32, + + family: SocketAddressFamily, + + options: NonInheritedOptions, +} + +impl TcpSocket { + /// Create a new socket in the given family. + pub(crate) fn new(ctx: &WasiSocketsCtx, family: SocketAddressFamily) -> std::io::Result { + ctx.allowed_network_uses.check_allowed_tcp()?; + + with_ambient_tokio_runtime(|| { + let socket = match family { + SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?, + SocketAddressFamily::Ipv6 => { + let socket = tokio::net::TcpSocket::new_v6()?; + sockopt::set_ipv6_v6only(&socket, true)?; + socket + } + }; + + Ok(Self::from_state(TcpState::Default(socket), family)) + }) + } + + #[cfg(feature = "p3")] + pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self { + TcpSocket::from_state(TcpState::Error(err), family) + } + + /// Creates a new socket with the `result` of an accepted socket from a + /// `TcpListener`. + /// + /// This will handle the `result` internally and `result` should be the raw + /// result from a TCP listen operation. + pub(crate) fn new_accept( + result: io::Result, + options: &NonInheritedOptions, + family: SocketAddressFamily, + ) -> io::Result { + let client = result.map_err(|err| match Errno::from_io_error(&err) { + // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS + // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, + // > or the service provider is still processing a callback function. + // + // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, + // because in POSIX this error is only returned by a non-blocking + // `connect` and wasi-sockets has a different solution for that. + #[cfg(windows)] + Some(Errno::INPROGRESS) => Errno::INTR.into(), + + // Normalize Linux' non-standard behavior. 
+ // + // From https://man7.org/linux/man-pages/man2/accept.2.html: + // > Linux accept() passes already-pending network errors on the + // > new socket as an error code from accept(). This behavior + // > differs from other BSD socket implementations. (...) + #[cfg(target_os = "linux")] + Some( + Errno::CONNRESET + | Errno::NETRESET + | Errno::HOSTUNREACH + | Errno::HOSTDOWN + | Errno::NETDOWN + | Errno::NETUNREACH + | Errno::PROTO + | Errno::NOPROTOOPT + | Errno::NONET + | Errno::OPNOTSUPP, + ) => Errno::CONNABORTED.into(), + + _ => err, + })?; + options.apply(family, &client); + Ok(Self::from_state( + TcpState::Connected(Arc::new(client)), + family, + )) + } + + /// Create a `TcpSocket` from an existing socket. + fn from_state(state: TcpState, family: SocketAddressFamily) -> Self { + Self { + tcp_state: state, + listen_backlog_size: DEFAULT_TCP_BACKLOG, + family, + options: Default::default(), + } + } + + pub(crate) fn as_std_view(&self) -> Result, ErrorCode> { + match &self.tcp_state { + TcpState::Default(socket) + | TcpState::BindStarted(socket) + | TcpState::Bound(socket) + | TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()), + TcpState::Connected(stream) => Ok(stream.as_socketlike_view()), + #[cfg(feature = "p3")] + TcpState::Receiving(stream) => Ok(stream.as_socketlike_view()), + TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()), + TcpState::P2Streaming(state) => Ok(state.stream.as_socketlike_view()), + TcpState::Connecting(..) 
| TcpState::ConnectReady(_) | TcpState::Closed => { + Err(ErrorCode::InvalidState) + } + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + } + } + + pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> { + let ip = addr.ip(); + if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) { + return Err(ErrorCode::InvalidArgument); + } + match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Default(sock) => { + if let Err(err) = tcp_bind(&sock, addr) { + self.tcp_state = TcpState::Default(sock); + Err(err) + } else { + self.tcp_state = TcpState::BindStarted(sock); + Ok(()) + } + } + tcp_state => { + self.tcp_state = tcp_state; + Err(ErrorCode::InvalidState) + } + } + } + + pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> { + match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::BindStarted(socket) => { + self.tcp_state = TcpState::Bound(socket); + Ok(()) + } + current_state => { + // Reset the state so that the outside world doesn't see this socket as closed + self.tcp_state = current_state; + Err(ErrorCode::NotInProgress) + } + } + } + + pub(crate) fn start_connect( + &mut self, + addr: &SocketAddr, + ) -> Result { + match self.tcp_state { + TcpState::Default(..) | TcpState::Bound(..) => {} + TcpState::Connecting(..) => { + return Err(ErrorCode::ConcurrencyConflict); + } + _ => return Err(ErrorCode::InvalidState), + }; + + if !is_valid_unicast_address(addr.ip()) + || !is_valid_remote_address(*addr) + || !is_valid_address_family(addr.ip(), self.family) + { + return Err(ErrorCode::InvalidArgument); + }; + + let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) = + mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) + else { + unreachable!(); + }; + + Ok(tokio_socket) + } + + /// For WASIp2 this is used to record the actual connection future as part + /// of `start_connect` within this socket state. 
+ pub(crate) fn set_pending_connect( + &mut self, + future: impl Future> + Send + 'static, + ) -> Result<(), ErrorCode> { + match &mut self.tcp_state { + TcpState::Connecting(slot @ None) => { + *slot = Some(Box::pin(future)); + Ok(()) + } + _ => Err(ErrorCode::InvalidState), + } + } + + /// For WASIp2 this retrieves the result from the future passed to + /// `set_pending_connect`. + /// + /// Return states here are: + /// + /// * `Ok(Some(res))` - where `res` is the result of the connect operation. + /// * `Ok(None)` - the connect operation isn't ready yet. + /// * `Err(e)` - a connect operation is not in progress. + pub(crate) fn take_pending_connect( + &mut self, + ) -> Result>, ErrorCode> { + match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) { + TcpState::ConnectReady(result) => Ok(Some(result)), + TcpState::Connecting(Some(mut future)) => { + let mut cx = Context::from_waker(Waker::noop()); + match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) { + Poll::Ready(result) => Ok(Some(result)), + Poll::Pending => { + self.tcp_state = TcpState::Connecting(Some(future)); + Ok(None) + } + } + } + current_state => { + self.tcp_state = current_state; + Err(ErrorCode::NotInProgress) + } + } + } + + pub(crate) fn finish_connect( + &mut self, + result: io::Result, + ) -> Result<(), ErrorCode> { + if !matches!(self.tcp_state, TcpState::Connecting(None)) { + return Err(ErrorCode::InvalidState); + } + match result { + Ok(stream) => { + self.tcp_state = TcpState::Connected(Arc::new(stream)); + Ok(()) + } + Err(err) => { + self.tcp_state = TcpState::Closed; + Err(ErrorCode::from(err)) + } + } + } + + pub(crate) fn start_listen(&mut self) -> Result<(), ErrorCode> { + match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Bound(tokio_socket) => { + self.tcp_state = TcpState::ListenStarted(tokio_socket); + Ok(()) + } + previous_state => { + self.tcp_state = previous_state; + Err(ErrorCode::InvalidState) + } + } + } + + pub(crate) fn 
finish_listen(&mut self) -> Result<(), ErrorCode> { + let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::ListenStarted(tokio_socket) => tokio_socket, + previous_state => { + self.tcp_state = previous_state; + return Err(ErrorCode::NotInProgress); + } + }; + + match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) { + Ok(listener) => { + self.tcp_state = TcpState::Listening { + listener: Arc::new(listener), + pending_accept: None, + }; + Ok(()) + } + Err(err) => { + self.tcp_state = TcpState::Closed; + + Err(match Errno::from_io_error(&err) { + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE + // According to the docs, `listen` can return EMFILE on Windows. + // This is odd, because we're not trying to create a new socket + // or file descriptor of any kind. So we rewrite it to less + // surprising error code. + // + // At the time of writing, this behavior has never been experimentally + // observed by any of the wasmtime authors, so we're relying fully + // on Microsoft's documentation here. 
+ #[cfg(windows)] + Some(Errno::MFILE) => Errno::NOBUFS.into(), + + _ => err.into(), + }) + } + } + } + + pub(crate) fn accept(&mut self) -> Result, ErrorCode> { + let TcpState::Listening { + listener, + pending_accept, + } = &mut self.tcp_state + else { + return Err(ErrorCode::InvalidState); + }; + + let result = match pending_accept.take() { + Some(result) => result, + None => { + let mut cx = std::task::Context::from_waker(Waker::noop()); + match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx)) + .map_ok(|(stream, _)| stream) + { + Poll::Ready(result) => result, + Poll::Pending => return Ok(None), + } + } + }; + + Ok(Some(Self::new_accept(result, &self.options, self.family)?)) + } + + #[cfg(feature = "p3")] + pub(crate) fn start_receive(&mut self) -> Option<&Arc> { + match mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Connected(stream) => { + self.tcp_state = TcpState::Receiving(stream); + Some(self.tcp_stream_arc().unwrap()) + } + prev => { + self.tcp_state = prev; + None + } + } + } + + pub(crate) fn local_address(&self) -> Result { + match &self.tcp_state { + TcpState::Bound(socket) => Ok(socket.local_addr()?), + TcpState::Connected(stream) => Ok(stream.local_addr()?), + #[cfg(feature = "p3")] + TcpState::Receiving(stream) => Ok(stream.local_addr()?), + TcpState::P2Streaming(state) => Ok(state.stream.local_addr()?), + TcpState::Listening { listener, .. } => Ok(listener.local_addr()?), + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + _ => Err(ErrorCode::InvalidState), + } + } + + pub(crate) fn remote_address(&self) -> Result { + let stream = self.tcp_stream_arc()?; + let addr = stream.peer_addr()?; + Ok(addr) + } + + pub(crate) fn is_listening(&self) -> bool { + matches!(self.tcp_state, TcpState::Listening { .. 
}) + } + + pub(crate) fn address_family(&self) -> SocketAddressFamily { + self.family + } + + pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> { + const MIN_BACKLOG: u32 = 1; + const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. + + if value == 0 { + return Err(ErrorCode::InvalidArgument); + } + // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. + let value = value + .try_into() + .unwrap_or(MAX_BACKLOG) + .clamp(MIN_BACKLOG, MAX_BACKLOG); + match &self.tcp_state { + TcpState::Default(..) | TcpState::Bound(..) => { + // Socket not listening yet. Stash value for first invocation to `listen`. + self.listen_backlog_size = value; + Ok(()) + } + TcpState::Listening { listener, .. } => { + // Try to update the backlog by calling `listen` again. + // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. + if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() { + return Err(ErrorCode::NotSupported); + } + self.listen_backlog_size = value; + Ok(()) + } + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + _ => Err(ErrorCode::InvalidState), + } + } + + pub(crate) fn keep_alive_enabled(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::socket_keepalive(fd)?; + Ok(v) + } + + pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + sockopt::set_socket_keepalive(fd, value)?; + Ok(()) + } + + pub(crate) fn keep_alive_idle_time(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepidle(fd)?; + Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) + } + + pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> { + let value = { + let fd = self.as_std_view()?; + set_keep_alive_idle_time(&*fd, value)? 
+ }; + self.options.set_keep_alive_idle_time(value); + Ok(()) + } + + pub(crate) fn keep_alive_interval(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepintvl(fd)?; + Ok(v.as_nanos().try_into().unwrap_or(u64::MAX)) + } + + pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + set_keep_alive_interval(fd, Duration::from_nanos(value))?; + Ok(()) + } + + pub(crate) fn keep_alive_count(&self) -> Result { + let fd = &*self.as_std_view()?; + let v = sockopt::tcp_keepcnt(fd)?; + Ok(v) + } + + pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> { + let fd = &*self.as_std_view()?; + set_keep_alive_count(fd, value)?; + Ok(()) + } + + pub(crate) fn hop_limit(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = get_unicast_hop_limit(fd, self.family)?; + Ok(n) + } + + pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> { + { + let fd = &*self.as_std_view()?; + set_unicast_hop_limit(fd, self.family, value)?; + } + self.options.set_hop_limit(value); + Ok(()) + } + + pub(crate) fn receive_buffer_size(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = receive_buffer_size(fd)?; + Ok(n) + } + + pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { + let res = { + let fd = &*self.as_std_view()?; + set_receive_buffer_size(fd, value)? + }; + self.options.set_receive_buffer_size(res); + Ok(()) + } + + pub(crate) fn send_buffer_size(&self) -> Result { + let fd = &*self.as_std_view()?; + let n = send_buffer_size(fd)?; + Ok(n) + } + + pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> { + let res = { + let fd = &*self.as_std_view()?; + set_send_buffer_size(fd, value)? 
+ }; + self.options.set_send_buffer_size(res); + Ok(()) + } + + #[cfg(feature = "p3")] + pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions { + &self.options + } + + #[cfg(feature = "p3")] + pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc, ErrorCode> { + match &self.tcp_state { + TcpState::Listening { listener, .. } => Ok(listener), + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + _ => Err(ErrorCode::InvalidState), + } + } + + pub(crate) fn tcp_stream_arc(&self) -> Result<&Arc, ErrorCode> { + match &self.tcp_state { + TcpState::Connected(socket) => Ok(socket), + #[cfg(feature = "p3")] + TcpState::Receiving(socket) => Ok(socket), + TcpState::P2Streaming(state) => Ok(&state.stream), + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + _ => Err(ErrorCode::InvalidState), + } + } + + pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> { + match &self.tcp_state { + TcpState::P2Streaming(state) => Ok(state), + #[cfg(feature = "p3")] + TcpState::Error(err) => Err(err.into()), + _ => Err(ErrorCode::InvalidState), + } + } + + pub(crate) fn set_p2_streaming_state( + &mut self, + state: P2TcpStreamingState, + ) -> Result<(), ErrorCode> { + if !matches!(self.tcp_state, TcpState::Connected(_)) { + return Err(ErrorCode::InvalidState); + } + self.tcp_state = TcpState::P2Streaming(Box::new(state)); + Ok(()) + } + + /// Used for `Pollable` in the WASIp2 implementation this awaits the socket + /// to be connected, if in the connecting state, or for a TCP accept to be + /// ready, if this is in the listening state. + /// + /// For all other states this method immediately returns. + pub(crate) async fn ready(&mut self) { + match &mut self.tcp_state { + TcpState::Default(..) + | TcpState::BindStarted(..) + | TcpState::Bound(..) + | TcpState::ListenStarted(..) + | TcpState::ConnectReady(..) + | TcpState::Closed + | TcpState::Connected { .. 
} + | TcpState::Connecting(None) + | TcpState::Listening { + pending_accept: Some(_), + .. + } + | TcpState::P2Streaming(_) => {} + + #[cfg(feature = "p3")] + TcpState::Receiving(_) | TcpState::Error(_) => {} + + TcpState::Connecting(Some(future)) => { + self.tcp_state = TcpState::ConnectReady(future.as_mut().await); + } + + TcpState::Listening { + listener, + pending_accept: slot @ None, + } => { + let result = futures::future::poll_fn(|cx| { + listener.poll_accept(cx).map_ok(|(stream, _)| stream) + }) + .await; + *slot = Some(result); + } + } + } +} + +#[cfg(not(target_os = "macos"))] +pub use inherits_option::*; +#[cfg(not(target_os = "macos"))] +mod inherits_option { + use crate::sockets::SocketAddressFamily; + use tokio::net::TcpStream; + + #[derive(Default, Clone)] + pub struct NonInheritedOptions; + + impl NonInheritedOptions { + pub fn set_keep_alive_idle_time(&mut self, _value: u64) {} + + pub fn set_hop_limit(&mut self, _value: u8) {} + + pub fn set_receive_buffer_size(&mut self, _value: usize) {} + + pub fn set_send_buffer_size(&mut self, _value: usize) {} + + pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {} + } +} + +#[cfg(target_os = "macos")] +pub use does_not_inherit_options::*; +#[cfg(target_os = "macos")] +mod does_not_inherit_options { + use crate::sockets::SocketAddressFamily; + use rustix::net::sockopt; + use std::sync::Arc; + use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed}; + use std::time::Duration; + use tokio::net::TcpStream; + + // The socket options below are not automatically inherited from the listener + // on all platforms. So we keep track of which options have been explicitly + // set and manually apply those values to newly accepted clients. 
+ #[derive(Default, Clone)] + pub struct NonInheritedOptions(Arc); + + #[derive(Default)] + struct Inner { + receive_buffer_size: AtomicUsize, + send_buffer_size: AtomicUsize, + hop_limit: AtomicU8, + keep_alive_idle_time: AtomicU64, // nanoseconds + } + + impl NonInheritedOptions { + pub fn set_keep_alive_idle_time(&mut self, value: u64) { + self.0.keep_alive_idle_time.store(value, Relaxed); + } + + pub fn set_hop_limit(&mut self, value: u8) { + self.0.hop_limit.store(value, Relaxed); + } + + pub fn set_receive_buffer_size(&mut self, value: usize) { + self.0.receive_buffer_size.store(value, Relaxed); + } + + pub fn set_send_buffer_size(&mut self, value: usize) { + self.0.send_buffer_size.store(value, Relaxed); + } + + pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) { + // Manually inherit socket options from listener. We only have to + // do this on platforms that don't already do this automatically + // and only if a specific value was explicitly set on the listener. + + let receive_buffer_size = self.0.receive_buffer_size.load(Relaxed); + if receive_buffer_size > 0 { + // Ignore potential error. + _ = sockopt::set_socket_recv_buffer_size(&stream, receive_buffer_size); + } + + let send_buffer_size = self.0.send_buffer_size.load(Relaxed); + if send_buffer_size > 0 { + // Ignore potential error. + _ = sockopt::set_socket_send_buffer_size(&stream, send_buffer_size); + } + + // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. + if family == SocketAddressFamily::Ipv6 { + let hop_limit = self.0.hop_limit.load(Relaxed); + if hop_limit > 0 { + // Ignore potential error. + _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hop_limit)); + } + } + + let keep_alive_idle_time = self.0.keep_alive_idle_time.load(Relaxed); + if keep_alive_idle_time > 0 { + // Ignore potential error. 
+ _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(keep_alive_idle_time)); + } + } + } +} diff --git a/crates/wasi/src/sockets/util.rs b/crates/wasi/src/sockets/util.rs index cb1043f48f3e..a9679a09d68d 100644 --- a/crates/wasi/src/sockets/util.rs +++ b/crates/wasi/src/sockets/util.rs @@ -28,6 +28,7 @@ pub enum ErrorCode { ConnectionAborted, DatagramTooLarge, NotInProgress, + ConcurrencyConflict, } impl fmt::Display for ErrorCode {