Skip to content

Commit 06a9a54

Browse files
committed
fix(proxy): address PR review feedback for connection error handling
Replace Notify + fixed sleep in the advisory lock test with pg_locks polling for deterministic synchronization. Send an ErrorResponse to clients for pre-split connection timeouts instead of a silent disconnect.
1 parent 48e9a62 commit 06a9a54

2 files changed

Lines changed: 45 additions & 13 deletions

File tree

packages/cipherstash-proxy-integration/src/connection_resilience.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
#[cfg(test)]
99
mod tests {
1010
use crate::common::{connect_with_tls, PG_PORT, PROXY};
11-
use std::sync::Arc;
1211
use std::time::Instant;
13-
use tokio::sync::Notify;
1412
use tokio::task::JoinSet;
1513
use tokio::time::{timeout, Duration};
14+
use tokio_postgres::SimpleQueryMessage;
1615

1716
/// Advisory lock ID used in isolation tests. Arbitrary value — just needs to be
1817
/// unique across concurrently running test suites against the same database.
@@ -120,9 +119,8 @@ mod tests {
120119

121120
/// An advisory-lock-blocked connection through the proxy does not block other proxy connections.
122121
///
123-
/// Note: Connection B notifies readiness before `pg_advisory_lock` reaches PostgreSQL.
124-
/// The 500ms sleep provides a generous margin for the lock attempt to reach PG, but is
125-
/// not strictly guaranteed. In practice this has not caused flakiness.
122+
/// Uses pg_locks polling to deterministically wait for client_b to be blocked on the
123+
/// advisory lock, rather than relying on a fixed sleep.
126124
#[tokio::test]
127125
async fn advisory_lock_blocked_connection_does_not_block_proxy() {
128126
let lock_query = format!("SELECT pg_advisory_lock({ADVISORY_LOCK_ID})");
@@ -136,15 +134,12 @@ mod tests {
136134
.await
137135
.unwrap();
138136

139-
let a_ready = Arc::new(Notify::new());
140-
let a_ready_tx = a_ready.clone();
141137
let b_lock_query = lock_query.clone();
142138
let b_unlock_query = unlock_query.clone();
143139

144140
// Connection B: through proxy, attempt to acquire the same lock (will block)
145141
let b_handle = tokio::spawn(async move {
146142
let client_b = connect_with_tls(PROXY).await;
147-
a_ready_tx.notify_one();
148143
// This will block until A releases the lock
149144
client_b
150145
.simple_query(&b_lock_query)
@@ -157,9 +152,23 @@ mod tests {
157152
.unwrap();
158153
});
159154

160-
// Wait for B to be connected and attempting the lock
161-
a_ready.notified().await;
162-
tokio::time::sleep(Duration::from_millis(500)).await;
155+
// Poll pg_locks until client_b is observed waiting for the advisory lock
156+
let poll_query = format!(
157+
"SELECT 1 FROM pg_locks WHERE locktype = 'advisory' AND NOT granted AND classid = 0 AND objid = {ADVISORY_LOCK_ID}"
158+
);
159+
let deadline = Instant::now() + Duration::from_secs(10);
160+
loop {
161+
let result = client_a.simple_query(&poll_query).await.unwrap();
162+
let has_waiting = result.iter().any(|m| matches!(m, SimpleQueryMessage::Row(_)));
163+
if has_waiting {
164+
break;
165+
}
166+
assert!(
167+
Instant::now() < deadline,
168+
"Timed out waiting for client_b to be blocked on advisory lock"
169+
);
170+
tokio::time::sleep(Duration::from_millis(50)).await;
171+
}
163172

164173
// Connection C: through proxy, should complete immediately despite B being blocked
165174
let start = Instant::now();

packages/cipherstash-proxy/src/postgresql/handler.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,14 @@ pub async fn handler(client_stream: AsyncStream, context: Context<ZeroKms>) -> R
6767

6868
loop {
6969
let startup_message =
70-
startup::read_message(&mut client_stream, context.connection_timeout()).await?;
70+
match startup::read_message(&mut client_stream, context.connection_timeout()).await {
71+
Ok(msg) => msg,
72+
Err(err @ Error::ConnectionTimeout { .. }) => {
73+
send_timeout_error(&mut client_stream, &err).await;
74+
return Err(err);
75+
}
76+
Err(err) => return Err(err),
77+
};
7178

7279
match &startup_message.code {
7380
StartupCode::SSLRequest => {
@@ -119,7 +126,14 @@ pub async fn handler(client_stream: AsyncStream, context: Context<ZeroKms>) -> R
119126

120127
let connection_timeout = context.connection_timeout();
121128
let (_code, bytes) =
122-
protocol::read_message(&mut client_stream, client_id, connection_timeout).await?;
129+
match protocol::read_message(&mut client_stream, client_id, connection_timeout).await {
130+
Ok(result) => result,
131+
Err(err @ Error::ConnectionTimeout { .. }) => {
132+
send_timeout_error(&mut client_stream, &err).await;
133+
return Err(err);
134+
}
135+
Err(err) => return Err(err),
136+
};
123137

124138
let password_message = PasswordMessage::try_from(&bytes)?;
125139

@@ -369,3 +383,12 @@ async fn scram_sha_256_plus_handler<S: AsyncRead + AsyncWrite + Unpin>(
369383
Err(ProtocolError::AuthenticationFailed.into())
370384
}
371385
}
386+
387+
/// Best-effort send of a connection timeout ErrorResponse directly to a client stream.
388+
/// Used for pre-split timeout sites where no ChannelWriter exists yet.
389+
async fn send_timeout_error<S: AsyncWrite + Unpin>(stream: &mut S, err: &Error) {
390+
let error_response = ErrorResponse::connection_timeout(err.to_string());
391+
if let Ok(bytes) = BytesMut::try_from(error_response) {
392+
let _ = stream.write_all(&bytes).await;
393+
}
394+
}

0 commit comments

Comments
 (0)