Skip to content

Commit d45a353

Browse files
committed
Fix re-connection bug
1 parent 86eabba commit d45a353

2 files changed

Lines changed: 79 additions & 63 deletions

File tree

src/http_version.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! HTTP Version support and negotiation
2+
use bytes::Bytes;
23
use hyper::Response;
34
use hyper_util::rt::TokioIo;
4-
use bytes::Bytes;
55
use std::pin::Pin;
66
use std::task::{Context, Poll};
77

@@ -59,11 +59,10 @@ impl HttpVersion {
5959
type Http1Sender = hyper::client::conn::http1::SendRequest<http_body_util::Full<Bytes>>;
6060
type Http2Sender = hyper::client::conn::http2::SendRequest<http_body_util::Full<Bytes>>;
6161

62-
type Http1Connection =
63-
hyper::client::conn::http1::Connection<
64-
TokioIo<ProxyClientTlsStream>,
65-
http_body_util::Full<Bytes>,
66-
>;
62+
type Http1Connection = hyper::client::conn::http1::Connection<
63+
TokioIo<ProxyClientTlsStream>,
64+
http_body_util::Full<Bytes>,
65+
>;
6766

6867
type Http2Connection = hyper::client::conn::http2::Connection<
6968
TokioIo<ProxyClientTlsStream>,

src/lib.rs

Lines changed: 74 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,20 @@ impl ProxyServer {
410410
let (_io, server_connection) = tls_stream.get_ref();
411411

412412
match server_connection.peer_certificates() {
413-
Some(remote_cert_chain) => Some(
414-
AttestedCertificateVerifier::extract_custom_attestation_from_cert(
415-
remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?,
416-
)?,
417-
),
413+
Some(remote_cert_chain) => remote_cert_chain
414+
.first()
415+
.and_then(|cert| {
416+
match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert)
417+
{
418+
Ok(attestation) => Some(attestation),
419+
Err(err) => {
420+
warn!(
421+
"Failed to extract remote attestation from inner-session certificate: {err}"
422+
);
423+
None
424+
}
425+
}
426+
}),
418427
None => None,
419428
}
420429
};
@@ -435,11 +444,15 @@ impl ProxyServer {
435444
let (_io, server_connection) = tls_stream.get_ref();
436445

437446
match server_connection.peer_certificates() {
438-
Some(remote_cert_chain) => Some(
439-
AttestedCertificateVerifier::extract_custom_attestation_from_cert(
440-
remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?,
441-
)?,
442-
),
447+
Some(remote_cert_chain) => remote_cert_chain.first().and_then(|cert| {
448+
match AttestedCertificateVerifier::extract_custom_attestation_from_cert(cert) {
449+
Ok(attestation) => Some(attestation),
450+
Err(err) => {
451+
warn!("Failed to extract remote attestation from certificate: {err}");
452+
None
453+
}
454+
}
455+
}),
443456
None => None,
444457
}
445458
};
@@ -461,7 +474,13 @@ impl ProxyServer {
461474
let (remote_attestation_type, measurements) = match attestation {
462475
Some(attestation) => (
463476
Some(attestation.attestation_type),
464-
attestation.get_measurements()?,
477+
match attestation.get_measurements() {
478+
Ok(measurements) => measurements,
479+
Err(err) => {
480+
warn!("Failed to extract measurements from peer attestation: {err}");
481+
None
482+
}
483+
},
465484
),
466485
None => (None, None),
467486
};
@@ -715,7 +734,7 @@ impl ProxyClient {
715734
let mut first = true;
716735
let mut ready_tx = Some(ready_tx);
717736
'reconnect: loop {
718-
let (mut sender, conn, attestation) =
737+
let (mut sender, conn) =
719738
// Connect to the proxy server and provide / verify attestation
720739
match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first)
721740
.await
@@ -745,9 +764,6 @@ impl ProxyClient {
745764
let (conn_done_tx, mut conn_done_rx) =
746765
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
747766

748-
let mut remote_attestation_type = attestation.attestation_type;
749-
let mut measurements = attestation.get_measurements().ok().flatten();
750-
751767
tokio::spawn(async move {
752768
let res = conn.await;
753769
let _ = conn_done_tx.send(res.err());
@@ -760,45 +776,60 @@ impl ProxyClient {
760776
debug!("[proxy-client] Read incoming request from source client: {req:?}");
761777
// Attempt to forward it to the proxy server
762778
let response = loop {
763-
match sender.send_request(req.clone()).await {
764-
Ok(mut resp) => {
765-
debug!("[proxy-client] Read response from proxy-server: {resp:?}");
766-
// If we have measurements from the proxy-server, inject them into the
767-
// response header
768-
let headers = resp.headers_mut();
769-
if let Some(measurements) = measurements.clone() {
770-
match measurements.to_header_format() {
771-
Ok(header_value) => {
772-
headers.insert(MEASUREMENT_HEADER, header_value);
779+
let send_result = tokio::select! {
780+
result = sender.send_request(req.clone()) => result,
781+
_ = conn_done_rx.changed() => {
782+
warn!("Connection dropped while request was in flight");
783+
match Self::setup_connection_with_backoff(
784+
&target,
785+
&nesting_tls_connector,
786+
true,
787+
)
788+
.await
789+
{
790+
Ok((new_sender, new_conn)) => {
791+
sender = new_sender;
792+
793+
let (new_conn_done_tx, new_conn_done_rx) =
794+
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
795+
conn_done_rx = new_conn_done_rx;
796+
797+
tokio::spawn(async move {
798+
let res = new_conn.await;
799+
let _ = new_conn_done_tx.send(res.err());
800+
});
801+
802+
warn!("Reconnected to proxy-server, retrying request");
803+
continue;
773804
}
774-
Err(e) => {
775-
// This error is highly unlikely - that the measurement values fail to
776-
// encode to JSON or fit in an HTTP header
777-
error!("Failed to encode measurement values: {e}");
805+
Err(reconnect_err) => {
806+
warn!("Reconnect after in-flight drop failed: {reconnect_err}");
807+
let mut resp = Response::new(full(
808+
"Request failed: connection to proxy-server dropped",
809+
));
810+
*resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
811+
break Ok(resp);
778812
}
779813
}
780814
}
815+
};
781816

782-
update_header(
783-
headers,
784-
ATTESTATION_TYPE_HEADER,
785-
remote_attestation_type.as_str(),
786-
);
817+
match send_result {
818+
Ok(resp) => {
819+
debug!("[proxy-client] Read response from proxy-server: {resp:?}");
787820
break Ok(resp.map(|b| b.boxed()));
788821
}
789822
Err(e) => {
790-
warn!("Failed to send request to proxy-server: {e}");
823+
warn!("Failed to send request to proxy-server: {e}");
791824
match Self::setup_connection_with_backoff(
792825
&target,
793826
&nesting_tls_connector,
794-
false,
827+
true,
795828
)
796829
.await
797830
{
798-
Ok((new_sender, new_conn, new_attestation)) => {
831+
Ok((new_sender, new_conn)) => {
799832
sender = new_sender;
800-
remote_attestation_type = new_attestation.attestation_type;
801-
measurements = new_attestation.get_measurements().ok().flatten();
802833

803834
let (new_conn_done_tx, new_conn_done_rx) =
804835
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
@@ -915,7 +946,7 @@ impl ProxyClient {
915946
target: &str,
916947
nesting_tls_connector: &NestingTlsConnector,
917948
should_bail: bool,
918-
) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> {
949+
) -> Result<(HttpSender, HttpConnection), ProxyError> {
919950
let mut delay = Duration::from_secs(1);
920951
let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS);
921952

@@ -944,7 +975,7 @@ impl ProxyClient {
944975
async fn setup_connection(
945976
nesting_tls_connector: &NestingTlsConnector,
946977
target: &str,
947-
) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> {
978+
) -> Result<(HttpSender, HttpConnection), ProxyError> {
948979
let outbound_stream = tokio::net::TcpStream::connect(target).await?;
949980

950981
let domain = server_name_from_host(target)?;
@@ -954,19 +985,6 @@ impl ProxyClient {
954985

955986
debug!("[proxy-client] Connected to proxy server");
956987

957-
// Get attestation from session
958-
let attestation = {
959-
let (_io, server_connection) = tls_stream.get_ref();
960-
961-
let remote_cert_chain = server_connection
962-
.peer_certificates()
963-
.ok_or(ProxyError::NoCertificate)?;
964-
965-
AttestedCertificateVerifier::extract_custom_attestation_from_cert(
966-
remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?,
967-
)?
968-
};
969-
970988
// The attestation exchange is now complete - setup an HTTP client
971989
let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream);
972990

@@ -990,8 +1008,7 @@ impl ProxyClient {
9901008
}
9911009
};
9921010

993-
// Return the HTTP client, as well as remote attestation
994-
Ok((sender, conn, attestation))
1011+
Ok((sender, conn))
9951012
}
9961013

9971014
// Handle a request from the source client to the proxy server
@@ -1364,7 +1381,7 @@ mod tests {
13641381
let nesting_tls_connector =
13651382
NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config));
13661383

1367-
let (sender, conn, _attestation) = ProxyClient::setup_connection(
1384+
let (sender, conn) = ProxyClient::setup_connection(
13681385
&nesting_tls_connector,
13691386
&format!("localhost:{}", proxy_addr.port()),
13701387
)

0 commit comments

Comments
 (0)