Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

---

## [0.3.2] - 2026-05-08

### Fixed
- **Cancellation-safe message reading** in `stream_loop` (issue #5)
- The wait-phase `tokio::select!` previously raced `stop_rx.changed()` and `tokio::time::timeout` against `read_backend_message_into`, whose internal `read_exact` calls are not cancellation-safe. If the timeout fired while a backend message was mid-flight, partially-read header/payload bytes were dropped along with the cancelled future, leaving the next iteration to mis-parse the wire stream - typically surfacing as a bogus `payload_len` followed by a hang or `Protocol` error
- New `MessageReader` externalizes partial-read state (header/payload counters, parsed tag/length) and uses one-shot `AsyncReadExt::read` so dropped futures never lose bytes; the next call resumes exactly where the previous one left off
- `stream_loop` now owns a single `MessageReader` reused across drain and wait phases

### Notes
- `read_backend_message_into` is retained for non-`select!` callers (startup, auth, replication-start) and is now documented as **not cancellation-safe**

---

## [0.3.1] - 2026-03-28

### Improved
Expand Down Expand Up @@ -100,7 +113,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
- Fuzz testing for pgwire framing


[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...HEAD
[Unreleased]: https://github.com/vnvo/pgwire-replication/compare/v0.3.2...HEAD
[0.3.2]: https://github.com/vnvo/pgwire-replication/compare/v0.3.1...v0.3.2
[0.3.1]: https://github.com/vnvo/pgwire-replication/compare/v0.3.0...v0.3.1
[0.3.0]: https://github.com/vnvo/pgwire-replication/compare/v0.2.0...v0.3.0
[0.2.0]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.2.0
[0.1.2]: https://github.com/vnvo/pgwire-replication/releases/tag/v0.1.2
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pgwire-replication"
version = "0.3.1"
version = "0.3.2"
edition = "2021"
rust-version = "1.88"
resolver = "2"
Expand Down
75 changes: 74 additions & 1 deletion benches/protocol_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
//!
//! Run with: `cargo bench --bench protocol_bench`

use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use std::io::Cursor;

use pgwire_replication::lsn::Lsn;
use pgwire_replication::protocol::framing::{read_backend_message_into, MessageReader};
use pgwire_replication::protocol::messages::{parse_error_response, ErrorFields};
use pgwire_replication::protocol::replication::{encode_standby_status_update, parse_copy_data};

Expand Down Expand Up @@ -96,12 +98,83 @@ fn bench_error_fields_parse(c: &mut Criterion) {
});
}

/// Build a buffer of N back-to-back CopyData messages of `data_size` payload.
fn make_copy_data_stream(count: usize, data_size: usize) -> Vec<u8> {
let payload = make_xlogdata_payload(data_size);
let frame_len = (4 + payload.len()) as i32;

let mut buf = Vec::with_capacity(count * (5 + payload.len()));
for _ in 0..count {
buf.push(b'd'); // CopyData tag
buf.extend_from_slice(&frame_len.to_be_bytes());
buf.extend_from_slice(&payload);
}
buf
}

/// Compare the throughput of the legacy `read_backend_message_into` (not
/// cancellation-safe) against the new `MessageReader::read` (cancellation-safe).
/// Both are driven against an in-memory `Cursor<Vec<u8>>` so the benchmark
/// isolates framing overhead from socket / scheduler effects.
fn bench_read_backend_message(c: &mut Criterion) {
const COUNT: usize = 256;
let mut group = c.benchmark_group("read_backend_message");

for size in [64, 256, 1024, 4096] {
let stream = make_copy_data_stream(COUNT, size);
group.throughput(Throughput::Bytes(stream.len() as u64));

// Legacy non-cancel-safe path (kept for compatibility).
group.bench_with_input(
BenchmarkId::new("read_backend_message_into", size),
&stream,
|b, stream| {
let rt = tokio::runtime::Builder::new_current_thread()
.build()
.unwrap();
b.iter(|| {
rt.block_on(async {
let mut cur = Cursor::new(black_box(stream.as_slice()));
let mut buf = BytesMut::with_capacity(4096);
for _ in 0..COUNT {
let _msg = read_backend_message_into(&mut cur, &mut buf).await.unwrap();
}
});
});
},
);

// New cancellation-safe path used by the streaming loop.
group.bench_with_input(
BenchmarkId::new("MessageReader", size),
&stream,
|b, stream| {
let rt = tokio::runtime::Builder::new_current_thread()
.build()
.unwrap();
b.iter(|| {
rt.block_on(async {
let mut cur = Cursor::new(black_box(stream.as_slice()));
let mut reader = MessageReader::new();
for _ in 0..COUNT {
let _msg = reader.read(&mut cur).await.unwrap();
}
});
});
},
);
}

group.finish();
}

criterion_group!(
benches,
bench_parse_xlogdata,
bench_parse_keepalive,
bench_encode_status_update,
bench_parse_error_response,
bench_error_fields_parse,
bench_read_backend_message,
);
criterion_main!(benches);
82 changes: 43 additions & 39 deletions src/client/worker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
Expand All @@ -9,8 +9,8 @@ use crate::config::ReplicationConfig;
use crate::error::{PgWireError, Result};
use crate::lsn::Lsn;
use crate::protocol::framing::{
read_backend_message, read_backend_message_into, write_copy_data, write_copy_done,
write_password_message, write_query, write_startup_message,
read_backend_message, write_copy_data, write_copy_done, write_password_message, write_query,
write_startup_message, MessageReader,
};
use crate::protocol::messages::{parse_auth_request, parse_error_response};
use crate::protocol::replication::{
Expand Down Expand Up @@ -208,14 +208,17 @@ impl WorkerState {
/// in a tight loop without `select!` or timeout overhead.
/// 2. **Wait phase**: when the buffer is empty, fall back to `select!` with
/// timeout + stop signal to handle idle keepalives and graceful shutdown.
///
/// Reads use [`MessageReader`], which preserves partial-read state across
/// dropped futures so the wait-phase `select!` is cancellation-safe.
async fn stream_loop<S: AsyncRead + AsyncWrite + Unpin>(
&mut self,
stream: &mut BufReader<S>,
) -> Result<()> {
let mut last_status_sent = Instant::now() - self.cfg.status_interval;
let mut last_applied = self.progress.load_applied();
// Reusable read buffer — avoids per-message allocation.
let mut read_buf = BytesMut::with_capacity(4096);
// Cancellation-safe message reader, partial reads survive dropped futures.
let mut reader = MessageReader::new();
// How many messages to process in the tight loop before checking
// stop signal and sending periodic status feedback.
const DRAIN_BATCH: usize = 256;
Expand All @@ -239,24 +242,22 @@ impl WorkerState {
// Read them in a tight loop to avoid select!/timeout overhead per message.
let mut drained = 0usize;
while stream.buffer().len() >= 5 && drained < DRAIN_BATCH {
let msg = read_backend_message_into(stream, &mut read_buf).await?;
let msg = reader.read(stream).await?;
drained += 1;
match msg.tag {
b'd' => {
if self
.handle_copy_data(
stream,
msg.payload,
&mut last_applied,
&mut last_status_sent,
)
.await?
{
return Ok(());
}
}
b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
_ => {}
if msg.tag == b'E' {
return Err(PgWireError::Server(parse_error_response(&msg.payload)));
}
if msg.tag == b'd'
&& self
.handle_copy_data(
stream,
msg.payload,
&mut last_applied,
&mut last_status_sent,
)
.await?
{
return Ok(());
}
}

Expand All @@ -272,6 +273,11 @@ impl WorkerState {
}

// ── Wait phase: buffer empty, need to wait for socket data ──
//
// Both `stop_rx.changed()` and the timeout can drop the read future
// mid-message. `MessageReader::read` is cancellation-safe — partial
// header/payload state lives on `reader` and is preserved across the
// drop, so the next iteration resumes the read without losing bytes.
let msg = tokio::select! {
biased;

Expand All @@ -285,7 +291,7 @@ impl WorkerState {

msg_result = tokio::time::timeout(
self.cfg.idle_wakeup_interval,
read_backend_message_into(stream, &mut read_buf),
reader.read(stream),
) => {
match msg_result {
Ok(res) => res?,
Expand All @@ -300,22 +306,20 @@ impl WorkerState {
}
};

match msg.tag {
b'd' => {
if self
.handle_copy_data(
stream,
msg.payload,
&mut last_applied,
&mut last_status_sent,
)
.await?
{
return Ok(());
}
}
b'E' => return Err(PgWireError::Server(parse_error_response(&msg.payload))),
_ => {}
if msg.tag == b'E' {
return Err(PgWireError::Server(parse_error_response(&msg.payload)));
}
if msg.tag == b'd'
&& self
.handle_copy_data(
stream,
msg.payload,
&mut last_applied,
&mut last_status_sent,
)
.await?
{
return Ok(());
}
}
}
Expand Down
Loading
Loading