diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c4de7b..1b6208d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/). --- +## [0.3.2] - 2026-05-08 + +### Fixed +- **Cancellation-safe message reading** in `stream_loop` (issue #5) + - The wait-phase `tokio::select!` previously raced `stop_rx.changed()` and `tokio::time::timeout` against `read_backend_message_into`, whose internal `read_exact` calls are not cancellation-safe. If the timeout fired while a backend message was mid-flight, partially-read header/payload bytes were dropped along with the cancelled future, leaving the next iteration to mis-parse the wire stream - typically surfacing as a bogus `payload_len` followed by a hang or `Protocol` error + - New `MessageReader` externalizes partial-read state (header/payload counters, parsed tag/length) and uses one-shot `AsyncReadExt::read` so dropped futures never lose bytes; the next call resumes exactly where the previous one left off + - `stream_loop` now owns a single `MessageReader` reused across drain and wait phases + +### Notes +- `read_backend_message_into` is retained for non-`select!` callers (startup, auth, replication-start) and is now documented as **not cancellation-safe** + +--- + ## [0.3.1] - 2026-03-28 ### Improved @@ -100,7 +113,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/). 
- Fuzz testing for pgwire framing -[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...HEAD +[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.2...HEAD +[0.3.2]: https://github.com/vnvo/pgwire-replication/compare/v0.3.1...v0.3.2 +[0.3.1]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...v0.3.1 [0.3.0]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.2.0 [0.1.2]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.2 diff --git a/Cargo.toml b/Cargo.toml index c12c37e..b00eb63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgwire-replication" -version = "0.3.1" +version = "0.3.2" edition = "2021" rust-version = "1.88" resolver = "2" diff --git a/benches/protocol_bench.rs b/benches/protocol_bench.rs index 7c27b1f..e7692dd 100644 --- a/benches/protocol_bench.rs +++ b/benches/protocol_bench.rs @@ -2,10 +2,12 @@ //! //! Run with: `cargo bench --bench protocol_bench` -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::io::Cursor; use pgwire_replication::lsn::Lsn; +use pgwire_replication::protocol::framing::{read_backend_message_into, MessageReader}; use pgwire_replication::protocol::messages::{parse_error_response, ErrorFields}; use pgwire_replication::protocol::replication::{encode_standby_status_update, parse_copy_data}; @@ -96,6 +98,76 @@ fn bench_error_fields_parse(c: &mut Criterion) { }); } +/// Build a buffer of N back-to-back CopyData messages of `data_size` payload. 
+fn make_copy_data_stream(count: usize, data_size: usize) -> Vec { + let payload = make_xlogdata_payload(data_size); + let frame_len = (4 + payload.len()) as i32; + + let mut buf = Vec::with_capacity(count * (5 + payload.len())); + for _ in 0..count { + buf.push(b'd'); // CopyData tag + buf.extend_from_slice(&frame_len.to_be_bytes()); + buf.extend_from_slice(&payload); + } + buf +} + +/// Compare the throughput of the legacy `read_backend_message_into` (not +/// cancellation-safe) against the new `MessageReader::read` (cancellation-safe). +/// Both are driven against an in-memory `Cursor<&[u8]>` so the benchmark +/// isolates framing overhead from socket / scheduler effects. +fn bench_read_backend_message(c: &mut Criterion) { + const COUNT: usize = 256; + let mut group = c.benchmark_group("read_backend_message"); + + for size in [64, 256, 1024, 4096] { + let stream = make_copy_data_stream(COUNT, size); + group.throughput(Throughput::Bytes(stream.len() as u64)); + + // Legacy non-cancel-safe path (kept for compatibility). + group.bench_with_input( + BenchmarkId::new("read_backend_message_into", size), + &stream, + |b, stream| { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + b.iter(|| { + rt.block_on(async { + let mut cur = Cursor::new(black_box(stream.as_slice())); + let mut buf = BytesMut::with_capacity(4096); + for _ in 0..COUNT { + let _msg = read_backend_message_into(&mut cur, &mut buf).await.unwrap(); + } + }); + }); + }, + ); + + // New cancellation-safe path used by the streaming loop. 
+ group.bench_with_input( + BenchmarkId::new("MessageReader", size), + &stream, + |b, stream| { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + b.iter(|| { + rt.block_on(async { + let mut cur = Cursor::new(black_box(stream.as_slice())); + let mut reader = MessageReader::new(); + for _ in 0..COUNT { + let _msg = reader.read(&mut cur).await.unwrap(); + } + }); + }); + }, + ); + } + + group.finish(); +} + criterion_group!( benches, bench_parse_xlogdata, @@ -103,5 +175,6 @@ criterion_group!( bench_encode_status_update, bench_parse_error_response, bench_error_fields_parse, + bench_read_backend_message, ); criterion_main!(benches); diff --git a/src/client/worker.rs b/src/client/worker.rs index a41b3ad..37bb654 100644 --- a/src/client/worker.rs +++ b/src/client/worker.rs @@ -1,4 +1,4 @@ -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, BufReader}; @@ -9,8 +9,8 @@ use crate::config::ReplicationConfig; use crate::error::{PgWireError, Result}; use crate::lsn::Lsn; use crate::protocol::framing::{ - read_backend_message, read_backend_message_into, write_copy_data, write_copy_done, - write_password_message, write_query, write_startup_message, + read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query, + write_startup_message, MessageReader, }; use crate::protocol::messages::{parse_auth_request, parse_error_response}; use crate::protocol::replication::{ @@ -208,14 +208,17 @@ impl WorkerState { /// in a tight loop without `select!` or timeout overhead. /// 2. **Wait phase**: when the buffer is empty, fall back to `select!` with /// timeout + stop signal to handle idle keepalives and graceful shutdown. + /// + /// Reads use [`MessageReader`], which preserves partial-read state across + /// dropped futures so the wait-phase `select!` is cancellation-safe. 
async fn stream_loop( &mut self, stream: &mut BufReader, ) -> Result<()> { let mut last_status_sent = Instant::now() - self.cfg.status_interval; let mut last_applied = self.progress.load_applied(); - // Reusable read buffer — avoids per-message allocation. - let mut read_buf = BytesMut::with_capacity(4096); + // Cancellation-safe message reader, partial reads survive dropped futures. + let mut reader = MessageReader::new(); // How many messages to process in the tight loop before checking // stop signal and sending periodic status feedback. const DRAIN_BATCH: usize = 256; @@ -239,24 +242,22 @@ impl WorkerState { // Read them in a tight loop to avoid select!/timeout overhead per message. let mut drained = 0usize; while stream.buffer().len() >= 5 && drained < DRAIN_BATCH { - let msg = read_backend_message_into(stream, &mut read_buf).await?; + let msg = reader.read(stream).await?; drained += 1; - match msg.tag { - b'd' => { - if self - .handle_copy_data( - stream, - msg.payload, - &mut last_applied, - &mut last_status_sent, - ) - .await? - { - return Ok(()); - } - } - b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))), - _ => {} + if msg.tag == b'E' { + return Err(PgWireError::Server(parse_error_response(&msg.payload))); + } + if msg.tag == b'd' + && self + .handle_copy_data( + stream, + msg.payload, + &mut last_applied, + &mut last_status_sent, + ) + .await? + { + return Ok(()); } } @@ -272,6 +273,11 @@ impl WorkerState { } // ── Wait phase: buffer empty, need to wait for socket data ── + // + // Both `stop_rx.changed()` and the timeout can drop the read future + // mid-message. `MessageReader::read` is cancellation-safe — partial + // header/payload state lives on `reader` and is preserved across the + // drop, so the next iteration resumes the read without losing bytes. let msg = tokio::select! 
{ biased; @@ -285,7 +291,7 @@ impl WorkerState { msg_result = tokio::time::timeout( self.cfg.idle_wakeup_interval, - read_backend_message_into(stream, &mut read_buf), + reader.read(stream), ) => { match msg_result { Ok(res) => res?, @@ -300,22 +306,20 @@ impl WorkerState { } }; - match msg.tag { - b'd' => { - if self - .handle_copy_data( - stream, - msg.payload, - &mut last_applied, - &mut last_status_sent, - ) - .await? - { - return Ok(()); - } - } - b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))), - _ => {} + if msg.tag == b'E' { + return Err(PgWireError::Server(parse_error_response(&msg.payload))); + } + if msg.tag == b'd' + && self + .handle_copy_data( + stream, + msg.payload, + &mut last_applied, + &mut last_status_sent, + ) + .await? + { + return Ok(()); } } } diff --git a/src/protocol/framing.rs b/src/protocol/framing.rs index e0c08f4..98f21a5 100644 --- a/src/protocol/framing.rs +++ b/src/protocol/framing.rs @@ -1,4 +1,5 @@ use bytes::{BufMut, Bytes, BytesMut}; +use std::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use crate::error::{PgWireError, Result}; @@ -46,10 +47,131 @@ impl BackendMessage { } pub async fn read_backend_message(rd: &mut R) -> Result { - read_backend_message_into(rd, &mut BytesMut::new()).await + let mut reader = MessageReader::new(); + reader.read(rd).await } -/// Read a backend message, reusing `buf` to avoid per-message allocation. +/// Cancellation-safe backend message reader. +/// +/// PostgreSQL backend messages span multiple `read` operations (5-byte header, +/// then a variable payload). A naive implementation using `read_exact` is +/// **not** cancellation-safe: if the future is dropped between reads (e.g. by +/// `tokio::select!` or `tokio::time::timeout`), bytes already pulled from the +/// underlying stream are lost and the next read mis-parses the wire stream. +/// +/// `MessageReader` externalizes the partial-read state so it survives across +/// dropped futures. 
Each call to [`read`](Self::read) uses one-shot +/// `AsyncReadExt::read` (which **is** cancel-safe) and accumulates progress +/// on `self`. If the returned future is dropped, no bytes are lost; the next +/// invocation resumes from where the previous one left off. +pub struct MessageReader { + hdr: [u8; 5], + hdr_filled: usize, + payload: BytesMut, + payload_filled: usize, + /// `Some` once the header has been fully read and parsed; reset to + /// `None` after each completed message. + payload_len: Option, + tag: u8, +} + +impl MessageReader { + pub fn new() -> Self { + Self::with_capacity(4096) + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + hdr: [0u8; 5], + hdr_filled: 0, + payload: BytesMut::with_capacity(capacity), + payload_filled: 0, + payload_len: None, + tag: 0, + } + } + + /// Read the next complete backend message. + /// + /// Cancellation-safe: dropping the returned future preserves all progress + /// so far on `self`. Re-call to resume. + pub async fn read(&mut self, rd: &mut R) -> Result { + // Phase 1: fill the 5-byte header + while self.hdr_filled < 5 { + let n = rd.read(&mut self.hdr[self.hdr_filled..]).await?; + if n == 0 { + return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF while reading backend message header", + )))); + } + self.hdr_filled += n; + } + + // Phase 2: parse the header (idempotent — runs once per message) + if self.payload_len.is_none() { + let len = i32::from_be_bytes([self.hdr[1], self.hdr[2], self.hdr[3], self.hdr[4]]); + + if len < 4 { + // Reset so the reader is reusable after a protocol error is + // surfaced (callers typically tear down on this anyway). 
+ self.hdr_filled = 0; + return Err(PgWireError::Protocol(format!( + "invalid backend message length: {len}" + ))); + } + + let payload_len = (len - 4) as usize; + + if payload_len > MAX_MESSAGE_SIZE { + self.hdr_filled = 0; + return Err(PgWireError::Protocol(format!( + "backend message too large: {payload_len} bytes (max {MAX_MESSAGE_SIZE})" + ))); + } + + self.tag = self.hdr[0]; + self.payload.clear(); + self.payload.resize(payload_len, 0); + self.payload_filled = 0; + self.payload_len = Some(payload_len); + } + + let payload_len = self.payload_len.unwrap(); + + // Phase 3: fill the payload + while self.payload_filled < payload_len { + let n = rd.read(&mut self.payload[self.payload_filled..]).await?; + if n == 0 { + return Err(PgWireError::Io(std::sync::Arc::new(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF while reading backend message payload", + )))); + } + self.payload_filled += n; + } + + // Phase 4: take payload, reset state for next message + let payload = self.payload.split().freeze(); + let tag = self.tag; + self.hdr_filled = 0; + self.payload_len = None; + self.payload_filled = 0; + + Ok(BackendMessage { tag, payload }) + } +} + +impl Default for MessageReader { + fn default() -> Self { + Self::new() + } +} + +/// Read a single backend message, reusing the provided buffer. +/// +/// **Not** cancellation-safe — see [`MessageReader`] for a cancel-safe +/// alternative used in the streaming loop. 
pub async fn read_backend_message_into( rd: &mut R, buf: &mut BytesMut, @@ -175,6 +297,7 @@ pub async fn write_copy_done(wr: &mut W) -> Result<()> { mod tests { use super::*; use std::io::Cursor; + use tokio::io::AsyncWriteExt; #[tokio::test] async fn read_backend_message_parses_valid_message() { @@ -209,6 +332,118 @@ mod tests { assert!(err.to_string().contains("invalid backend message length")); } + #[tokio::test] + async fn message_reader_reads_complete_message() { + // Tag 'Z' (ReadyForQuery), length=5 (4 + 1 byte payload), payload='I' + let data = [b'Z', 0, 0, 0, 5, b'I']; + let mut cursor = Cursor::new(&data[..]); + + let mut reader = MessageReader::new(); + let msg = reader.read(&mut cursor).await.unwrap(); + assert_eq!(msg.tag, b'Z'); + assert_eq!(&msg.payload[..], b"I"); + } + + #[tokio::test] + async fn message_reader_reads_back_to_back_messages() { + // Two messages on one stream: ReadyForQuery + NoticeResponse w/ empty payload + let data = [b'Z', 0, 0, 0, 5, b'I', b'N', 0, 0, 0, 4]; + let mut cursor = Cursor::new(&data[..]); + + let mut reader = MessageReader::new(); + + let m1 = reader.read(&mut cursor).await.unwrap(); + assert_eq!(m1.tag, b'Z'); + assert_eq!(&m1.payload[..], b"I"); + + let m2 = reader.read(&mut cursor).await.unwrap(); + assert_eq!(m2.tag, b'N'); + assert!(m2.payload.is_empty()); + } + + /// Regression test for issue #5: reading a backend message must be + /// cancellation-safe so that `tokio::select!` / `tokio::time::timeout` + /// dropping the read future mid-message does not corrupt the stream. + /// + /// With the old `read_backend_message_into`, dropping the future after + /// 3 of 5 header bytes were consumed would lose those 3 bytes and + /// re-parse the next bytes as a new header, producing a bogus length + /// and a Protocol error (or worse, a silent desync). 
+ #[tokio::test] + async fn message_reader_resumes_after_cancellation_mid_header() { + let (mut writer, mut rd) = tokio::io::duplex(64); + let mut reader = MessageReader::new(); + + // Tag 'd' (CopyData), length = 8 (4 + 4-byte payload), payload b"abcd" + let header = [b'd', 0, 0, 0, 8]; + let payload = b"abcd"; + + // Deliver only the first 3 header bytes, then cancel. + writer.write_all(&header[..3]).await.unwrap(); + + let timed_out = + tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await; + assert!( + timed_out.is_err(), + "read must time out while waiting for remaining header bytes" + ); + + // Deliver the remaining bytes. A correct cancel-safe reader resumes + // and returns the original message intact. + writer.write_all(&header[3..]).await.unwrap(); + writer.write_all(payload).await.unwrap(); + + let msg = reader.read(&mut rd).await.unwrap(); + assert_eq!(msg.tag, b'd'); + assert_eq!(&msg.payload[..], payload); + } + + /// Ensures partial-payload cancellation also resumes correctly. + #[tokio::test] + async fn message_reader_resumes_after_cancellation_mid_payload() { + let (mut writer, mut rd) = tokio::io::duplex(64); + let mut reader = MessageReader::new(); + + // 16-byte payload to ensure we can split it. + let payload: [u8; 16] = std::array::from_fn(|i| i as u8); + let len = (4 + payload.len()) as i32; + let header = [ + b'd', + (len >> 24) as u8, + (len >> 16) as u8, + (len >> 8) as u8, + len as u8, + ]; + + // Full header + first 5 bytes of payload, then cancel. + writer.write_all(&header).await.unwrap(); + writer.write_all(&payload[..5]).await.unwrap(); + + let timed_out = + tokio::time::timeout(std::time::Duration::from_millis(20), reader.read(&mut rd)).await; + assert!( + timed_out.is_err(), + "read must time out while waiting for remaining payload bytes" + ); + + // Deliver the rest. 
+ writer.write_all(&payload[5..]).await.unwrap(); + + let msg = reader.read(&mut rd).await.unwrap(); + assert_eq!(msg.tag, b'd'); + assert_eq!(&msg.payload[..], &payload[..]); + } + + #[tokio::test] + async fn message_reader_rejects_invalid_length() { + let data = [b'Z', 0, 0, 0, 3]; + let mut cursor = Cursor::new(&data[..]); + + let mut reader = MessageReader::new(); + let err = reader.read(&mut cursor).await.unwrap_err(); + assert!(err.to_string().contains("invalid backend message length")); + } + #[tokio::test] async fn read_backend_message_rejects_oversized_message() { // length = MAX_MESSAGE_SIZE + 5 (over limit)