From d9725f3833399c8cc9d064f1188c6ea9e1d9b432 Mon Sep 17 00:00:00 2001 From: Ben Schofield Date: Mon, 25 May 2026 02:11:31 +0000 Subject: [PATCH 1/2] feat: add custom socket transport support for Postgres and MySQL Add methods to PgConnection and MySqlConnection that accept pre-connected sockets implementing AsyncRead + AsyncWrite, enabling custom transport layers (vsock, QUIC, turmoil, SSH tunnels, etc.) without forking sqlx. Per maintainer feedback on #4187, this uses AsyncRead + AsyncWrite traits instead of exposing the internal Socket trait. Two separate methods are provided for each runtime's trait set: - connect_with_custom_tokio(): accepts tokio::io::{AsyncRead, AsyncWrite} - connect_with_custom_futures(): accepts futures_io::{AsyncRead, AsyncWrite} Also adds PoolOptions::connector() so pools can use custom transports: PgPoolOptions::new() .connector(|options| async move { let socket = VsockStream::connect(addr).await?; PgConnection::connect_with_custom_tokio(socket, &options).await }) .connect_with(options) .await? TLS upgrade is negotiated automatically based on the connection options. No new public trait exposure. No behavioral changes to existing code. --- Cargo.toml | 8 +- sqlx-core/src/net/mod.rs | 6 + sqlx-core/src/net/socket/async_rw_adapter.rs | 550 +++++++++++++++++++ sqlx-core/src/net/socket/mod.rs | 2 + sqlx-core/src/pool/inner.rs | 10 +- sqlx-core/src/pool/options.rs | 52 ++ sqlx-mysql/Cargo.toml | 5 + sqlx-mysql/src/connection/establish.rs | 76 +++ sqlx-postgres/Cargo.toml | 10 + sqlx-postgres/src/connection/establish.rs | 62 ++- sqlx-postgres/src/connection/stream.rs | 17 + 11 files changed, 791 insertions(+), 7 deletions(-) create mode 100644 sqlx-core/src/net/socket/async_rw_adapter.rs diff --git a/Cargo.toml b/Cargo.toml index 3738b814ca..30f73f39cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,10 +94,10 @@ _unstable-docs = [ ] # Base runtime features without TLS -runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor"] -runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std"] -runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol"] -runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio"] +runtime-async-global-executor = ["_rt-async-global-executor", "sqlx-core/_rt-async-global-executor", "sqlx-macros?/_rt-async-global-executor", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"] +runtime-async-std = ["_rt-async-std", "sqlx-core/_rt-async-std", "sqlx-macros?/_rt-async-std", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"] +runtime-smol = ["_rt-smol", "sqlx-core/_rt-smol", "sqlx-macros?/_rt-smol", "sqlx-postgres?/_rt-async-io", "sqlx-mysql?/_rt-async-io"] +runtime-tokio = ["_rt-tokio", "sqlx-core/_rt-tokio", "sqlx-macros?/_rt-tokio", "sqlx-postgres?/_rt-tokio", "sqlx-mysql?/_rt-tokio"] # TLS features tls-native-tls = ["sqlx-core/_tls-native-tls", "sqlx-macros?/_tls-native-tls"] diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..e9da7b427c 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -4,3 +4,9 @@ pub mod tls; pub use socket::{ connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, }; + +#[cfg(feature = "_rt-tokio")] +pub use socket::async_rw_adapter::TokioStream; + +#[cfg(feature = "_rt-async-io")] +pub use socket::async_rw_adapter::FuturesStream; diff --git a/sqlx-core/src/net/socket/async_rw_adapter.rs b/sqlx-core/src/net/socket/async_rw_adapter.rs new file mode 100644 index 0000000000..08cf453234 --- /dev/null +++ b/sqlx-core/src/net/socket/async_rw_adapter.rs @@ -0,0 +1,550 @@ +//! Adapters that bridge `AsyncRead + AsyncWrite` implementations into sqlx's internal [`Socket`] trait. +//! +//! These adapters exist so users can pass pre-connected streams (vsock, QUIC, turmoil, etc.) +//! to sqlx without exposing the `Socket` trait as public API. +//! +//! ## Design notes +//! +//! The [`Socket`] trait uses a split-phase read model: `poll_read_ready` signals data is available, +//! then `try_read` synchronously copies from an internal buffer. Since `AsyncRead` doesn't have a +//! separate readiness notification, `poll_read_ready` performs the actual read into an internal +//! buffer, and `try_read` drains from it. +//! +//! `try_write` uses a noop waker to attempt a non-blocking poll_write. This is safe because +//! the caller (`Write` future in `socket/mod.rs`) always calls `poll_write_ready(cx)` with the +//! real task waker when `try_write` returns `WouldBlock`, ensuring proper wakeup registration. + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::BufMut; + +use crate::io::ReadBuf; +use crate::net::Socket; + +/// Internal buffer size for the read-ahead used by `poll_read_ready`. +const ADAPTER_BUF_SIZE: usize = 8192; + +/// Generates an adapter struct + `Socket` impl for a given async I/O trait family. +macro_rules! impl_socket_adapter { + ( + $(#[$meta:meta])* + $name:ident, + feature = $feature:literal, + bounds($($bound:path),+), + poll_read($self:ident, $cx:ident, $buf:ident) => $poll_read_expr:expr, + poll_write($self_w:ident, $cx_w:ident, $buf_w:ident) => $poll_write_expr:expr, + poll_flush($self_f:ident, $cx_f:ident) => $poll_flush_expr:expr, + poll_shutdown($self_s:ident, $cx_s:ident) => $poll_shutdown_expr:expr $(,)? + ) => { + $(#[$meta])* + #[cfg(feature = $feature)] + pub struct $name { + inner: S, + read_buf: Vec, + read_len: usize, + read_pos: usize, + } + + #[cfg(feature = $feature)] + impl $name { + pub fn new(inner: S) -> Self { + Self { + inner, + read_buf: vec![0u8; ADAPTER_BUF_SIZE], + read_len: 0, + read_pos: 0, + } + } + + fn buffered(&self) -> &[u8] { + &self.read_buf[self.read_pos..self.read_len] + } + } + + #[cfg(feature = $feature)] + impl Socket for $name + where + S: $($bound +)+ Send + Sync + Unpin + 'static, + { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + let buffered = self.buffered(); + if !buffered.is_empty() { + let to_copy = std::cmp::min(buffered.len(), buf.remaining_mut()); + buf.put_slice(&buffered[..to_copy]); + self.read_pos += to_copy; + if self.read_pos == self.read_len { + self.read_pos = 0; + self.read_len = 0; + } + return Ok(to_copy); + } + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + // Safe to use noop_waker here: if Pending is returned, the caller + // (Socket::write future) will call poll_write_ready(cx) with the real + // task context, which re-registers the proper waker. + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + let $self_w = &mut self.inner; + let $buf_w = buf; + let $cx_w = &mut cx; + match $poll_write_expr { + Poll::Ready(result) => result, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.buffered().is_empty() { + return Poll::Ready(Ok(())); + } + + self.read_pos = 0; + self.read_len = 0; + + let $cx = cx; + let $self = &mut self.inner; + let $buf = &mut self.read_buf; + match $poll_read_expr { + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::Error::from( + io::ErrorKind::UnexpectedEof, + ))); + } + self.read_len = n; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Attempt a zero-byte write to check if the underlying stream is writable. + // This registers the real waker with the I/O resource so we get woken + // when the socket becomes writable. + let $cx_w = cx; + let $self_w = &mut self.inner; + let $buf_w: &[u8] = &[]; + match $poll_write_expr { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + let $cx_f = cx; + let $self_f = &mut self.inner; + $poll_flush_expr + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + let $cx_s = cx; + let $self_s = &mut self.inner; + $poll_shutdown_expr + } + } + }; +} + +impl_socket_adapter! { + /// Adapter that wraps a tokio [`AsyncRead`][tokio::io::AsyncRead] + + /// [`AsyncWrite`][tokio::io::AsyncWrite] into a [`Socket`] implementation. + TokioStream, + feature = "_rt-tokio", + bounds(tokio::io::AsyncRead, tokio::io::AsyncWrite), + poll_read(inner, cx, buf) => { + let mut read_buf = tokio::io::ReadBuf::new(buf); + match Pin::new(inner).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + }, + poll_write(inner, cx, buf) => { + Pin::new(inner).poll_write(cx, buf) + }, + poll_flush(inner, cx) => { + Pin::new(inner).poll_flush(cx) + }, + poll_shutdown(inner, cx) => { + Pin::new(inner).poll_shutdown(cx) + }, +} + +impl_socket_adapter! { + /// Adapter that wraps a futures-io [`AsyncRead`][futures_io::AsyncRead] + + /// [`AsyncWrite`][futures_io::AsyncWrite] into a [`Socket`] implementation. + FuturesStream, + feature = "_rt-async-io", + bounds(futures_io::AsyncRead, futures_io::AsyncWrite), + poll_read(inner, cx, buf) => { + Pin::new(inner).poll_read(cx, buf) + }, + poll_write(inner, cx, buf) => { + Pin::new(inner).poll_write(cx, buf) + }, + poll_flush(inner, cx) => { + Pin::new(inner).poll_flush(cx) + }, + poll_shutdown(inner, cx) => { + Pin::new(inner).poll_close(cx) + }, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "_rt-tokio")] + mod tokio_adapter { + use super::*; + use crate::net::Socket; + use bytes::BytesMut; + use std::task::Poll; + + #[test] + fn try_read_returns_would_block_when_empty() { + let stream = tokio::io::duplex(64).0; + let mut adapter = TokioStream::new(stream); + let mut buf = BytesMut::with_capacity(32); + let err = adapter.try_read(&mut buf).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + } + + #[test] + fn poll_read_ready_fills_buffer_then_try_read_drains() { + let (client, mut server) = tokio::io::duplex(64); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + use tokio::io::AsyncWriteExt; + server.write_all(b"hello world").await.unwrap(); + + let mut adapter = TokioStream::new(client); + let mut buf = BytesMut::with_capacity(32); + + let poll = std::future::poll_fn(|cx| adapter.poll_read_ready(cx)).await; + assert!(poll.is_ok()); + + let n = adapter.try_read(&mut buf).unwrap(); + assert_eq!(&buf[..n], b"hello world"); + + let mut buf2 = BytesMut::with_capacity(32); + let err = adapter.try_read(&mut buf2).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + }); + } + + #[test] + fn try_write_writes_data() { + let (client, mut server) = tokio::io::duplex(64); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + use tokio::io::AsyncReadExt; + let mut adapter = TokioStream::new(client); + + let n = std::future::poll_fn(|cx| match adapter.try_write(b"test data") { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match adapter.poll_write_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(adapter.try_write(b"test data")), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + other => Poll::Ready(other), + }) + .await + .unwrap(); + + assert_eq!(n, 9); + + let mut read_buf = vec![0u8; 32]; + let n = server.read(&mut read_buf).await.unwrap(); + assert_eq!(&read_buf[..n], b"test data"); + }); + } + + #[test] + fn partial_drain_preserves_remaining() { + let (client, mut server) = tokio::io::duplex(64); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + use tokio::io::AsyncWriteExt; + server.write_all(b"abcdefghij").await.unwrap(); + + let mut adapter = TokioStream::new(client); + + std::future::poll_fn(|cx| adapter.poll_read_ready(cx)) + .await + .unwrap(); + + let mut buf = [0u8; 4]; + let n = adapter.try_read(&mut buf.as_mut_slice()).unwrap(); + assert_eq!(n, 4); + assert_eq!(&buf, b"abcd"); + + let mut buf2 = [0u8; 32]; + let n = adapter.try_read(&mut buf2.as_mut_slice()).unwrap(); + assert_eq!(n, 6); + assert_eq!(&buf2[..6], b"efghij"); + }); + } + + #[test] + fn poll_read_ready_returns_eof_on_closed_stream() { + let (client, server) = tokio::io::duplex(64); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + // Drop the server side to close the stream + drop(server); + + let mut adapter = TokioStream::new(client); + let err = std::future::poll_fn(|cx| adapter.poll_read_ready(cx)) + .await + .unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + }); + } + + #[test] + fn large_data_spans_multiple_buffer_fills() { + let (client, mut server) = tokio::io::duplex(64 * 1024); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + use tokio::io::AsyncWriteExt; + + // Write more than ADAPTER_BUF_SIZE (8192) bytes + let data: Vec = (0..20_000).map(|i| (i % 256) as u8).collect(); + server.write_all(&data).await.unwrap(); + + let mut adapter = TokioStream::new(client); + let mut received = BytesMut::with_capacity(20_000); + + // Read all data through multiple poll_read_ready/try_read cycles + while received.len() < 20_000 { + std::future::poll_fn(|cx| adapter.poll_read_ready(cx)) + .await + .unwrap(); + // Drain everything available in the internal buffer + loop { + match adapter.try_read(&mut received) { + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => panic!("unexpected error: {e}"), + } + } + } + + assert_eq!(received.len(), 20_000); + assert_eq!(&received[..], &data[..]); + }); + } + + } + + #[cfg(feature = "_rt-async-io")] + mod futures_adapter { + use super::*; + use crate::net::Socket; + use bytes::BytesMut; + use std::task::Poll; + + /// A simple in-memory duplex using futures_io traits via Cursor. + /// We use `futures_util::io::Cursor` which implements AsyncRead + AsyncWrite. + struct MemStream { + /// Data available for reading + read_data: std::io::Cursor>, + /// Written data collected here + write_data: Vec, + } + + impl MemStream { + fn new(data: &[u8]) -> Self { + Self { + read_data: std::io::Cursor::new(data.to_vec()), + write_data: Vec::new(), + } + } + } + + impl futures_io::AsyncRead for MemStream { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + use std::io::Read; + let n = self.read_data.read(buf)?; + Poll::Ready(Ok(n)) + } + } + + impl futures_io::AsyncWrite for MemStream { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.write_data.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[test] + fn try_read_returns_would_block_when_empty() { + let stream = MemStream::new(b""); + let mut adapter = FuturesStream::new(stream); + let mut buf = BytesMut::with_capacity(32); + // Empty stream: poll_read_ready returns UnexpectedEof, try_read returns WouldBlock + let err = adapter.try_read(&mut buf).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + } + + #[test] + fn poll_read_ready_fills_buffer_then_try_read_drains() { + let stream = MemStream::new(b"hello futures"); + let mut adapter = FuturesStream::new(stream); + let mut buf = BytesMut::with_capacity(32); + + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + // poll_read_ready should fill internal buffer + match adapter.poll_read_ready(&mut cx) { + Poll::Ready(Ok(())) => {} + other => panic!("expected Ready(Ok(())), got {:?}", other), + } + + // try_read should drain from internal buffer + let n = adapter.try_read(&mut buf).unwrap(); + assert_eq!(&buf[..n], b"hello futures"); + + // After draining, try_read should return WouldBlock + let mut buf2 = BytesMut::with_capacity(32); + let err = adapter.try_read(&mut buf2).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + } + + #[test] + fn try_write_writes_data() { + let stream = MemStream::new(b""); + let mut adapter = FuturesStream::new(stream); + + let n = adapter.try_write(b"test data").unwrap(); + assert_eq!(n, 9); + } + + #[test] + fn partial_drain_preserves_remaining() { + let stream = MemStream::new(b"abcdefghij"); + let mut adapter = FuturesStream::new(stream); + + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + // Fill internal buffer + match adapter.poll_read_ready(&mut cx) { + Poll::Ready(Ok(())) => {} + other => panic!("expected Ready(Ok(())), got {:?}", other), + } + + // Read only 4 bytes + let mut buf = [0u8; 4]; + let n = adapter.try_read(&mut buf.as_mut_slice()).unwrap(); + assert_eq!(n, 4); + assert_eq!(&buf, b"abcd"); + + // Remaining 6 bytes should still be available + let mut buf2 = [0u8; 32]; + let n = adapter.try_read(&mut buf2.as_mut_slice()).unwrap(); + assert_eq!(n, 6); + assert_eq!(&buf2[..6], b"efghij"); + } + + #[test] + fn poll_read_ready_returns_eof_on_empty_stream() { + // MemStream with empty data simulates a closed/EOF stream + let stream = MemStream::new(b""); + let mut adapter = FuturesStream::new(stream); + + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + match adapter.poll_read_ready(&mut cx) { + Poll::Ready(Err(e)) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof), + other => panic!("expected UnexpectedEof, got {:?}", other), + } + } + + #[test] + fn large_data_spans_multiple_buffer_fills() { + // Write more than ADAPTER_BUF_SIZE (8192) bytes + let data: Vec = (0..20_000).map(|i| (i % 256) as u8).collect(); + let stream = MemStream::new(&data); + let mut adapter = FuturesStream::new(stream); + + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let mut received = BytesMut::with_capacity(20_000); + + // Read all data through multiple poll_read_ready/try_read cycles + while received.len() < 20_000 { + match adapter.poll_read_ready(&mut cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => panic!("unexpected error: {e}"), + Poll::Pending => panic!("unexpected Pending"), + } + loop { + match adapter.try_read(&mut received) { + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => panic!("unexpected error: {e}"), + } + } + } + + assert_eq!(received.len(), 20_000); + assert_eq!(&received[..], &data[..]); + } + } +} diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 0f9aae61b4..1350f5d4d9 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -10,6 +10,8 @@ use cfg_if::cfg_if; use crate::io::ReadBuf; +#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] +pub mod async_rw_adapter; mod buffered; pub trait Socket: Send + Sync + Unpin + 'static { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index b698dc9df0..a4699b5606 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -5,6 +5,7 @@ use crate::database::Database; use crate::error::Error; use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions}; use crossbeam_queue::ArrayQueue; +use futures_core::future::BoxFuture; use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; @@ -347,7 +348,14 @@ impl PoolInner { // result here is `Result, TimeoutError>` // if this block does not return, sleep for the backoff timeout and try again - match crate::rt::timeout(timeout, connect_options.connect()).await { + let connect_fut: BoxFuture<'_, Result> = + if let Some(connector) = &self.options.connector { + connector(&connect_options) + } else { + Box::pin(connect_options.connect()) + }; + + match crate::rt::timeout(timeout, connect_fut).await { // successfully established connection Ok(Ok(mut raw)) => { // See comment on `PoolOptions::after_connect` diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 3d048f1795..03970e39fd 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -6,6 +6,7 @@ use crate::pool::Pool; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; +use std::future::Future; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -85,6 +86,16 @@ pub struct PoolOptions { pub(crate) fair: bool, pub(crate) parent_pool: Option>, + + pub(crate) connector: Option< + Arc< + dyn Fn( + &::Options, + ) -> BoxFuture<'static, Result> + + Send + + Sync, + >, + >, } // Manually implement `Clone` to avoid a trait bound issue. @@ -107,6 +118,7 @@ impl Clone for PoolOptions { idle_timeout: self.idle_timeout, fair: self.fair, parent_pool: self.parent_pool.clone(), + connector: self.connector.clone(), } } } @@ -162,6 +174,7 @@ impl PoolOptions { max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, parent_pool: None, + connector: None, } } @@ -513,6 +526,45 @@ impl PoolOptions { self } + /// Set a custom connector that the pool calls to create new connections. + /// + /// This overrides the default behavior of calling [`ConnectOptions::connect()`]. + /// Use this to connect over custom transports (vsock, QUIC, turmoil, SSH tunnels, etc.). + /// + /// The closure receives a clone of the current [`ConnectOptions`][Connection::Options] + /// and must return a future resolving to a new connection. + /// + /// # Example + /// + /// ```ignore + /// use sqlx::postgres::{PgConnectOptions, PgConnection, PgPoolOptions}; + /// + /// # async fn example() -> Result<(), sqlx::Error> { + /// let pool = PgPoolOptions::new() + /// .max_connections(5) + /// .connector(|options| async move { + /// let stream = tokio::net::TcpStream::connect("127.0.0.1:5432").await?; + /// PgConnection::connect_raw_tokio(stream, &options).await + /// }) + /// .connect_with( + /// PgConnectOptions::new().username("postgres").database("mydb"), + /// ) + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub fn connector(mut self, connector: F) -> Self + where + F: Fn(::Options) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + DB::Connection: Sized, + { + self.connector = Some(Arc::new(move |options| { + Box::pin(connector(options.clone())) + })); + self + } + /// Create a new pool from this `PoolOptions` and immediately open at least one connection. /// /// This ensures the configuration is correct. diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index 7ffb529c8f..fa9cea96f9 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -16,6 +16,8 @@ any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive", "bitflags/serde"] migrate = ["sqlx-core/migrate", "dep:crc"] rsa = ["dep:rand", "dep:rsa"] +_rt-tokio = ["sqlx-core/_rt-tokio", "dep:tokio"] +_rt-async-io = ["sqlx-core/_rt-async-io", "dep:futures-io"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] @@ -61,6 +63,9 @@ thiserror.workspace = true serde = { version = "1.0.219", optional = true } +tokio = { workspace = true, optional = true } +futures-io = { version = "0.3.32", optional = true } + [dev-dependencies] # FIXME: https://github.com/rust-lang/cargo/issues/15622 sqlx = { path = "..", default-features = false, features = ["mysql"] } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..6cdcb338db 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -32,6 +32,82 @@ impl MySqlConnection { }), }) } + + /// Connect to a MySQL server over a pre-connected stream implementing + /// tokio's [`AsyncRead`][tokio::io::AsyncRead] + [`AsyncWrite`][tokio::io::AsyncWrite]. + /// + /// This allows using custom transport layers (e.g., vsock for AWS Nitro Enclaves, + /// QUIC streams, simulation frameworks like `turmoil`, or proxied connections) + /// without forking sqlx. + /// + /// TLS upgrade is negotiated automatically based on `options.ssl_mode`. + /// + /// # Example + /// + /// ```no_run + /// use sqlx::mysql::{MySqlConnectOptions, MySqlConnection}; + /// + /// # async fn example() -> Result<(), sqlx::Error> { + /// let stream = tokio::net::TcpStream::connect("127.0.0.1:3306").await?; + /// let options = MySqlConnectOptions::new() + /// .username("root") + /// .database("mydb"); + /// let conn = MySqlConnection::connect_raw_tokio(stream, &options).await?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "_rt-tokio")] + pub async fn connect_raw_tokio( + stream: S, + options: &MySqlConnectOptions, + ) -> Result + where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static, + { + use crate::net::TokioStream; + let do_handshake = DoHandshake::new(options)?; + let stream = do_handshake.do_handshake(TokioStream::new(stream)).await?; + Ok(Self { + inner: Box::new(MySqlConnectionInner { + stream, + transaction_depth: 0, + status_flags: Default::default(), + cache_statement: StatementCache::new(options.statement_cache_capacity), + log_settings: options.log_settings.clone(), + }), + }) + } + + /// Connect to a MySQL server over a pre-connected stream implementing + /// futures-io's [`AsyncRead`][futures_io::AsyncRead] + [`AsyncWrite`][futures_io::AsyncWrite]. + /// + /// This allows using custom transport layers (e.g., vsock, QUIC streams, + /// simulation frameworks, or proxied connections) without forking sqlx. + /// + /// TLS upgrade is negotiated automatically based on `options.ssl_mode`. + #[cfg(feature = "_rt-async-io")] + pub async fn connect_raw_futures( + stream: S, + options: &MySqlConnectOptions, + ) -> Result + where + S: futures_io::AsyncRead + futures_io::AsyncWrite + Send + Sync + Unpin + 'static, + { + use crate::net::FuturesStream; + let do_handshake = DoHandshake::new(options)?; + let stream = do_handshake + .do_handshake(FuturesStream::new(stream)) + .await?; + Ok(Self { + inner: Box::new(MySqlConnectionInner { + stream, + transaction_depth: 0, + status_flags: Default::default(), + cache_statement: StatementCache::new(options.statement_cache_capacity), + log_settings: options.log_settings.clone(), + }), + }) + } } struct DoHandshake<'a> { diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index d5bf41f9b1..c7600ed533 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -14,6 +14,8 @@ any = ["sqlx-core/any"] json = ["dep:serde", "dep:serde_json", "sqlx-core/json"] migrate = ["sqlx-core/migrate", "dep:crc"] offline = ["json", "sqlx-core/offline", "smallvec/serde"] +_rt-tokio = ["sqlx-core/_rt-tokio", "dep:tokio"] +_rt-async-io = ["sqlx-core/_rt-async-io", "dep:futures-io"] # Type Integration features bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"] @@ -75,6 +77,14 @@ serde_json = { version = "1.0.142", optional = true, features = ["raw_value"] } [dependencies.sqlx-core] workspace = true +[dependencies.tokio] +workspace = true +optional = true + +[dependencies.futures-io] +version = "0.3.32" +optional = true + [dev-dependencies.sqlx] # FIXME: https://github.com/rust-lang/cargo/issues/15622 # workspace = true diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 3c2f516533..52992f114a 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -16,9 +16,67 @@ use super::PgConnectionInner; impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { - // Upgrade to TLS if we were asked to and the server supports it - let mut stream = PgStream::connect(options).await?; + let stream = PgStream::connect(options).await?; + Self::establish_with_stream(stream, options).await + } + + /// Connect to a PostgreSQL server over a pre-connected stream implementing + /// tokio's [`AsyncRead`][tokio::io::AsyncRead] + [`AsyncWrite`][tokio::io::AsyncWrite]. + /// + /// This allows using custom transport layers (e.g., vsock for AWS Nitro Enclaves, + /// QUIC streams, simulation frameworks like `turmoil`, or proxied connections) + /// without forking sqlx. + /// + /// TLS upgrade is negotiated automatically based on `options.ssl_mode`. + /// + /// # Example + /// + /// ```no_run + /// use sqlx::postgres::{PgConnectOptions, PgConnection}; + /// + /// # async fn example() -> Result<(), sqlx::Error> { + /// let stream = tokio::net::TcpStream::connect("127.0.0.1:5432").await?; + /// let options = PgConnectOptions::new() + /// .username("postgres") + /// .database("mydb"); + /// let conn = PgConnection::connect_raw_tokio(stream, &options).await?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "_rt-tokio")] + pub async fn connect_raw_tokio(stream: S, options: &PgConnectOptions) -> Result + where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static, + { + use crate::net::TokioStream; + let stream = PgStream::with_socket(TokioStream::new(stream), options).await?; + Self::establish_with_stream(stream, options).await + } + + /// Connect to a PostgreSQL server over a pre-connected stream implementing + /// futures-io's [`AsyncRead`][futures_io::AsyncRead] + [`AsyncWrite`][futures_io::AsyncWrite]. + /// + /// This allows using custom transport layers (e.g., vsock, QUIC streams, + /// simulation frameworks, or proxied connections) without forking sqlx. + /// + /// TLS upgrade is negotiated automatically based on `options.ssl_mode`. + #[cfg(feature = "_rt-async-io")] + pub async fn connect_raw_futures( + stream: S, + options: &PgConnectOptions, + ) -> Result + where + S: futures_io::AsyncRead + futures_io::AsyncWrite + Send + Sync + Unpin + 'static, + { + use crate::net::FuturesStream; + let stream = PgStream::with_socket(FuturesStream::new(stream), options).await?; + Self::establish_with_stream(stream, options).await + } + async fn establish_with_stream( + mut stream: PgStream, + options: &PgConnectOptions, + ) -> Result { // To begin a session, a frontend opens a connection to the server // and sends a startup message. diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index e8a1aedc47..a6fb605c43 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -13,6 +13,8 @@ use crate::message::{ BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification, ParameterStatus, ReceivedMessage, }; +#[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] +use crate::net::WithSocket; use crate::net::{self, BufferedSocket, Socket}; use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; @@ -57,6 +59,21 @@ impl PgStream { }) } + #[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] + pub(super) async fn with_socket( + socket: S, + options: &PgConnectOptions, + ) -> Result { + let socket = MaybeUpgradeTls(options).with_socket(socket).await?; + + Ok(Self { + inner: BufferedSocket::new(socket), + notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, + }) + } + #[inline(always)] pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { self.write(EncodeMessage(message)) From 3bb6d73bfd880da8ba1be25af926541d166823f8 Mon Sep 17 00:00:00 2001 From: Ben Schofield Date: Mon, 25 May 2026 18:03:12 +0000 Subject: [PATCH 2/2] test: add integration tests for connect_raw_tokio with and without TLS --- sqlx-core/src/net/socket/async_rw_adapter.rs | 313 ++++++++++--------- tests/postgres/postgres.rs | 120 +++++++ 2 files changed, 286 insertions(+), 147 deletions(-) diff --git a/sqlx-core/src/net/socket/async_rw_adapter.rs b/sqlx-core/src/net/socket/async_rw_adapter.rs index 08cf453234..18856e5ed1 100644 --- a/sqlx-core/src/net/socket/async_rw_adapter.rs +++ b/sqlx-core/src/net/socket/async_rw_adapter.rs @@ -26,176 +26,196 @@ use crate::net::Socket; /// Internal buffer size for the read-ahead used by `poll_read_ready`. const ADAPTER_BUF_SIZE: usize = 8192; -/// Generates an adapter struct + `Socket` impl for a given async I/O trait family. -macro_rules! impl_socket_adapter { - ( - $(#[$meta:meta])* - $name:ident, - feature = $feature:literal, - bounds($($bound:path),+), - poll_read($self:ident, $cx:ident, $buf:ident) => $poll_read_expr:expr, - poll_write($self_w:ident, $cx_w:ident, $buf_w:ident) => $poll_write_expr:expr, - poll_flush($self_f:ident, $cx_f:ident) => $poll_flush_expr:expr, - poll_shutdown($self_s:ident, $cx_s:ident) => $poll_shutdown_expr:expr $(,)? - ) => { - $(#[$meta])* - #[cfg(feature = $feature)] - pub struct $name { - inner: S, - read_buf: Vec, - read_len: usize, - read_pos: usize, +// ─── Tokio adapter ─────────────────────────────────────────────────────────── + +/// Adapter that wraps a tokio [`AsyncRead`][tokio::io::AsyncRead] + +/// [`AsyncWrite`][tokio::io::AsyncWrite] into a [`Socket`] implementation. +#[cfg(feature = "_rt-tokio")] +pub struct TokioStream { + inner: S, + read_buf: Vec, + read_len: usize, + read_pos: usize, +} + +#[cfg(feature = "_rt-tokio")] +impl TokioStream { + pub fn new(inner: S) -> Self { + Self { + inner, + read_buf: vec![0u8; ADAPTER_BUF_SIZE], + read_len: 0, + read_pos: 0, } + } - #[cfg(feature = $feature)] - impl $name { - pub fn new(inner: S) -> Self { - Self { - inner, - read_buf: vec![0u8; ADAPTER_BUF_SIZE], - read_len: 0, - read_pos: 0, - } - } + fn buffered(&self) -> &[u8] { + &self.read_buf[self.read_pos..self.read_len] + } +} - fn buffered(&self) -> &[u8] { - &self.read_buf[self.read_pos..self.read_len] +#[cfg(feature = "_rt-tokio")] +impl Socket for TokioStream +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static, +{ + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + let buffered = self.buffered(); + if !buffered.is_empty() { + let to_copy = std::cmp::min(buffered.len(), buf.remaining_mut()); + buf.put_slice(&buffered[..to_copy]); + self.read_pos += to_copy; + if self.read_pos == self.read_len { + self.read_pos = 0; + self.read_len = 0; } + return Ok(to_copy); } + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } - #[cfg(feature = $feature)] - impl Socket for $name - where - S: $($bound +)+ Send + Sync + Unpin + 'static, - { - fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { - let buffered = self.buffered(); - if !buffered.is_empty() { - let to_copy = std::cmp::min(buffered.len(), buf.remaining_mut()); - buf.put_slice(&buffered[..to_copy]); - self.read_pos += to_copy; - if self.read_pos == self.read_len { - self.read_pos = 0; - self.read_len = 0; - } - return Ok(to_copy); - } - Err(io::Error::from(io::ErrorKind::WouldBlock)) - } + fn try_write(&mut self, buf: &[u8]) -> io::Result { + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + match Pin::new(&mut self.inner).poll_write(&mut cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.buffered().is_empty() { + return Poll::Ready(Ok(())); + } + + self.read_pos = 0; + self.read_len = 0; - fn try_write(&mut self, buf: &[u8]) -> io::Result { - // Safe to use noop_waker here: if Pending is returned, the caller - // (Socket::write future) will call poll_write_ready(cx) with the real - // task context, which re-registers the proper waker. - let waker = futures_util::task::noop_waker(); - let mut cx = Context::from_waker(&waker); - let $self_w = &mut self.inner; - let $buf_w = buf; - let $cx_w = &mut cx; - match $poll_write_expr { - Poll::Ready(result) => result, - Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + let mut read_buf = tokio::io::ReadBuf::new(&mut self.read_buf); + match Pin::new(&mut self.inner).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => { + let n = read_buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof))); } + self.read_len = n; + Poll::Ready(Ok(())) } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } - fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if !self.buffered().is_empty() { - return Poll::Ready(Ok(())); - } + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.inner).poll_write(cx, &[]) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} +// ─── Futures-io adapter ────────────────────────────────────────────────────── + +/// Adapter that wraps a futures-io [`AsyncRead`][futures_io::AsyncRead] + +/// [`AsyncWrite`][futures_io::AsyncWrite] into a [`Socket`] implementation. +#[cfg(feature = "_rt-async-io")] +pub struct FuturesStream { + inner: S, + read_buf: Vec, + read_len: usize, + read_pos: usize, +} + +#[cfg(feature = "_rt-async-io")] +impl FuturesStream { + pub fn new(inner: S) -> Self { + Self { + inner, + read_buf: vec![0u8; ADAPTER_BUF_SIZE], + read_len: 0, + read_pos: 0, + } + } + + fn buffered(&self) -> &[u8] { + &self.read_buf[self.read_pos..self.read_len] + } +} + +#[cfg(feature = "_rt-async-io")] +impl Socket for FuturesStream +where + S: futures_io::AsyncRead + futures_io::AsyncWrite + Send + Sync + Unpin + 'static, +{ + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + let buffered = self.buffered(); + if !buffered.is_empty() { + let to_copy = std::cmp::min(buffered.len(), buf.remaining_mut()); + buf.put_slice(&buffered[..to_copy]); + self.read_pos += to_copy; + if self.read_pos == self.read_len { self.read_pos = 0; self.read_len = 0; - - let $cx = cx; - let $self = &mut self.inner; - let $buf = &mut self.read_buf; - match $poll_read_expr { - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::Error::from( - io::ErrorKind::UnexpectedEof, - ))); - } - self.read_len = n; - Poll::Ready(Ok(())) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } } + return Ok(to_copy); + } + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } - fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - // Attempt a zero-byte write to check if the underlying stream is writable. - // This registers the real waker with the I/O resource so we get woken - // when the socket becomes writable. - let $cx_w = cx; - let $self_w = &mut self.inner; - let $buf_w: &[u8] = &[]; - match $poll_write_expr { - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } + fn try_write(&mut self, buf: &[u8]) -> io::Result { + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + match Pin::new(&mut self.inner).poll_write(&mut cx, buf) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } + } - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - let $cx_f = cx; - let $self_f = &mut self.inner; - $poll_flush_expr - } + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.buffered().is_empty() { + return Poll::Ready(Ok(())); + } + + self.read_pos = 0; + self.read_len = 0; - fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - let $cx_s = cx; - let $self_s = &mut self.inner; - $poll_shutdown_expr + match Pin::new(&mut self.inner).poll_read(cx, &mut self.read_buf) { + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof))); + } + self.read_len = n; + Poll::Ready(Ok(())) } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, } - }; -} + } -impl_socket_adapter! { - /// Adapter that wraps a tokio [`AsyncRead`][tokio::io::AsyncRead] + - /// [`AsyncWrite`][tokio::io::AsyncWrite] into a [`Socket`] implementation. - TokioStream, - feature = "_rt-tokio", - bounds(tokio::io::AsyncRead, tokio::io::AsyncWrite), - poll_read(inner, cx, buf) => { - let mut read_buf = tokio::io::ReadBuf::new(buf); - match Pin::new(inner).poll_read(cx, &mut read_buf) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())), + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.inner).poll_write(cx, &[]) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } - }, - poll_write(inner, cx, buf) => { - Pin::new(inner).poll_write(cx, buf) - }, - poll_flush(inner, cx) => { - Pin::new(inner).poll_flush(cx) - }, - poll_shutdown(inner, cx) => { - Pin::new(inner).poll_shutdown(cx) - }, -} + } -impl_socket_adapter! { - /// Adapter that wraps a futures-io [`AsyncRead`][futures_io::AsyncRead] + - /// [`AsyncWrite`][futures_io::AsyncWrite] into a [`Socket`] implementation. - FuturesStream, - feature = "_rt-async-io", - bounds(futures_io::AsyncRead, futures_io::AsyncWrite), - poll_read(inner, cx, buf) => { - Pin::new(inner).poll_read(cx, buf) - }, - poll_write(inner, cx, buf) => { - Pin::new(inner).poll_write(cx, buf) - }, - poll_flush(inner, cx) => { - Pin::new(inner).poll_flush(cx) - }, - poll_shutdown(inner, cx) => { - Pin::new(inner).poll_close(cx) - }, + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } } #[cfg(test)] @@ -370,7 +390,6 @@ mod tests { assert_eq!(&received[..], &data[..]); }); } - } #[cfg(feature = "_rt-async-io")] diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 126771565a..f1c35f8dcf 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -2218,3 +2218,123 @@ async fn it_can_recover_from_copy_in_invalid_params() -> anyhow::Result<()> { ) .await } + +#[cfg(feature = "_rt-tokio")] +#[sqlx_macros::test] +async fn it_connects_raw_tokio() -> anyhow::Result<()> { + setup_if_needed(); + + let db_url = env::var("DATABASE_URL")?; + let options: PgConnectOptions = db_url.parse()?; + + let stream = + tokio::net::TcpStream::connect(format!("{}:{}", options.get_host(), options.get_port())) + .await?; + + let mut conn = PgConnection::connect_raw_tokio(stream, &options).await?; + conn.ping().await?; + + let value: (i32,) = sqlx::query_as("SELECT 1 + 1").fetch_one(&mut conn).await?; + assert_eq!(value.0, 2); + + Ok(()) +} + +#[cfg(all( + feature = "_rt-tokio", + any( + feature = "tls-native-tls", + feature = "tls-rustls", + feature = "tls-rustls-aws-lc-rs", + feature = "tls-rustls-ring", + feature = "tls-rustls-ring-webpki", + feature = "tls-rustls-ring-native-roots", + ) +))] +#[sqlx_macros::test] +async fn it_connects_raw_tokio_with_tls() -> anyhow::Result<()> { + setup_if_needed(); + + let db_url = env::var("DATABASE_URL")?; + let options: PgConnectOptions = db_url + .parse::()? + .ssl_mode(sqlx::postgres::PgSslMode::Require); + + let stream = + tokio::net::TcpStream::connect(format!("{}:{}", options.get_host(), options.get_port())) + .await?; + + let mut conn = PgConnection::connect_raw_tokio(stream, &options).await?; + conn.ping().await?; + + // Verify TLS is actually in use by checking the connection's SSL status + let ssl: bool = sqlx::query_scalar("SELECT ssl FROM pg_stat_ssl WHERE pid = pg_backend_pid()") + .fetch_one(&mut conn) + .await?; + assert!(ssl, "expected connection to be using TLS"); + + Ok(()) +} + +#[cfg(feature = "_rt-async-std")] +#[sqlx_macros::test] +async fn it_connects_raw_futures() -> anyhow::Result<()> { + setup_if_needed(); + + let db_url = env::var("DATABASE_URL")?; + let options: PgConnectOptions = db_url.parse()?; + + let stream = async_std::net::TcpStream::connect(format!( + "{}:{}", + options.get_host(), + options.get_port() + )) + .await?; + + let mut conn = PgConnection::connect_raw_futures(stream, &options).await?; + conn.ping().await?; + + let value: (i32,) = sqlx::query_as("SELECT 1 + 1").fetch_one(&mut conn).await?; + assert_eq!(value.0, 2); + + Ok(()) +} + +#[cfg(all( + feature = "_rt-async-std", + any( + feature = "tls-native-tls", + feature = "tls-rustls", + feature = "tls-rustls-aws-lc-rs", + feature = "tls-rustls-ring", + feature = "tls-rustls-ring-webpki", + feature = "tls-rustls-ring-native-roots", + ) +))] +#[sqlx_macros::test] +async fn it_connects_raw_futures_with_tls() -> anyhow::Result<()> { + setup_if_needed(); + + let db_url = env::var("DATABASE_URL")?; + let options: PgConnectOptions = db_url + .parse::()? + .ssl_mode(sqlx::postgres::PgSslMode::Require); + + let stream = async_std::net::TcpStream::connect(format!( + "{}:{}", + options.get_host(), + options.get_port() + )) + .await?; + + let mut conn = PgConnection::connect_raw_futures(stream, &options).await?; + conn.ping().await?; + + // Verify TLS is actually in use by checking the connection's SSL status + let ssl: bool = sqlx::query_scalar("SELECT ssl FROM pg_stat_ssl WHERE pid = pg_backend_pid()") + .fetch_one(&mut conn) + .await?; + assert!(ssl, "expected connection to be using TLS"); + + Ok(()) +}