Skip to content

Commit 7ed43a9

Browse files
committed
Fully restore measurement header injection
1 parent d45a353 commit 7ed43a9

3 files changed

Lines changed: 121 additions & 23 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/lib.rs

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ impl ProxyClient {
734734
let mut first = true;
735735
let mut ready_tx = Some(ready_tx);
736736
'reconnect: loop {
737-
let (mut sender, conn) =
737+
let (mut sender, conn, attestation) =
738738
// Connect to the proxy server and provide / verify attestation
739739
match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first)
740740
.await
@@ -764,6 +764,9 @@ impl ProxyClient {
764764
let (conn_done_tx, mut conn_done_rx) =
765765
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
766766

767+
let mut remote_attestation_type = attestation.attestation_type;
768+
let mut measurements = attestation.get_measurements().ok().flatten();
769+
767770
tokio::spawn(async move {
768771
let res = conn.await;
769772
let _ = conn_done_tx.send(res.err());
@@ -787,8 +790,10 @@ impl ProxyClient {
787790
)
788791
.await
789792
{
790-
Ok((new_sender, new_conn)) => {
793+
Ok((new_sender, new_conn, new_attestation)) => {
791794
sender = new_sender;
795+
remote_attestation_type = new_attestation.attestation_type;
796+
measurements = new_attestation.get_measurements().ok().flatten();
792797

793798
let (new_conn_done_tx, new_conn_done_rx) =
794799
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
@@ -815,8 +820,26 @@ impl ProxyClient {
815820
};
816821

817822
match send_result {
818-
Ok(resp) => {
823+
Ok(mut resp) => {
819824
debug!("[proxy-client] Read response from proxy-server: {resp:?}");
825+
let headers = resp.headers_mut();
826+
if let Some(measurements) = measurements.clone() {
827+
match measurements.to_header_format() {
828+
Ok(header_value) => {
829+
headers.insert(MEASUREMENT_HEADER, header_value);
830+
}
831+
Err(e) => {
832+
error!("Failed to encode measurement values: {e}");
833+
}
834+
}
835+
}
836+
837+
update_header(
838+
headers,
839+
ATTESTATION_TYPE_HEADER,
840+
remote_attestation_type.as_str(),
841+
);
842+
820843
break Ok(resp.map(|b| b.boxed()));
821844
}
822845
Err(e) => {
@@ -828,8 +851,10 @@ impl ProxyClient {
828851
)
829852
.await
830853
{
831-
Ok((new_sender, new_conn)) => {
854+
Ok((new_sender, new_conn, new_attestation)) => {
832855
sender = new_sender;
856+
remote_attestation_type = new_attestation.attestation_type;
857+
measurements = new_attestation.get_measurements().ok().flatten();
833858

834859
let (new_conn_done_tx, new_conn_done_rx) =
835860
tokio::sync::watch::channel::<Option<hyper::Error>>(None);
@@ -946,7 +971,7 @@ impl ProxyClient {
946971
target: &str,
947972
nesting_tls_connector: &NestingTlsConnector,
948973
should_bail: bool,
949-
) -> Result<(HttpSender, HttpConnection), ProxyError> {
974+
) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> {
950975
let mut delay = Duration::from_secs(1);
951976
let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS);
952977

@@ -975,7 +1000,7 @@ impl ProxyClient {
9751000
async fn setup_connection(
9761001
nesting_tls_connector: &NestingTlsConnector,
9771002
target: &str,
978-
) -> Result<(HttpSender, HttpConnection), ProxyError> {
1003+
) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> {
9791004
let outbound_stream = tokio::net::TcpStream::connect(target).await?;
9801005

9811006
let domain = server_name_from_host(target)?;
@@ -985,6 +1010,18 @@ impl ProxyClient {
9851010

9861011
debug!("[proxy-client] Connected to proxy server");
9871012

1013+
let attestation = {
1014+
let (_io, server_connection) = tls_stream.get_ref();
1015+
1016+
let remote_cert_chain = server_connection
1017+
.peer_certificates()
1018+
.ok_or(ProxyError::NoCertificate)?;
1019+
1020+
AttestedCertificateVerifier::extract_custom_attestation_from_cert(
1021+
remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?,
1022+
)?
1023+
};
1024+
9881025
// The attestation exchange is now complete - setup an HTTP client
9891026
let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream);
9901027

@@ -1008,7 +1045,7 @@ impl ProxyClient {
10081045
}
10091046
};
10101047

1011-
Ok((sender, conn))
1048+
Ok((sender, conn, attestation))
10121049
}
10131050

10141051
// Handle a request from the source client to the proxy server
@@ -1207,6 +1244,7 @@ where
12071244
#[cfg(test)]
12081245
mod tests {
12091246
use attestation::{AttestationType, measurements::MeasurementPolicy};
1247+
use std::collections::HashMap;
12101248
use tokio_rustls::TlsConnector;
12111249

12121250
use super::*;
@@ -1215,6 +1253,43 @@ mod tests {
12151253
generate_tls_config_with_client_auth, init_tracing,
12161254
};
12171255

1256+
fn expected_mock_measurements() -> HashMap<String, String> {
1257+
let zero_measurement = "0".repeat(96);
1258+
HashMap::from([
1259+
("0".to_string(), zero_measurement.clone()),
1260+
("1".to_string(), zero_measurement.clone()),
1261+
("2".to_string(), zero_measurement.clone()),
1262+
("3".to_string(), zero_measurement.clone()),
1263+
("4".to_string(), zero_measurement),
1264+
])
1265+
}
1266+
1267+
fn assert_mock_measurements(body: &str) {
1268+
let parsed: HashMap<String, String> = serde_json::from_str(body).unwrap();
1269+
assert_eq!(parsed, expected_mock_measurements());
1270+
}
1271+
1272+
fn assert_mock_measurements_header(headers: &http::HeaderMap) {
1273+
let body = headers
1274+
.get(MEASUREMENT_HEADER)
1275+
.and_then(|v| v.to_str().ok())
1276+
.unwrap();
1277+
assert_mock_measurements(body);
1278+
}
1279+
1280+
fn assert_attestation_type_header(headers: &http::HeaderMap, expected: &str) {
1281+
assert_eq!(
1282+
headers
1283+
.get(ATTESTATION_TYPE_HEADER)
1284+
.and_then(|v| v.to_str().ok()),
1285+
Some(expected)
1286+
);
1287+
}
1288+
1289+
fn assert_no_measurements_header(headers: &http::HeaderMap) {
1290+
assert!(headers.get(MEASUREMENT_HEADER).is_none());
1291+
}
1292+
12181293
#[test]
12191294
fn proxy_alpn_protocols_prefer_http2() {
12201295
let mut protocols = Vec::new();
@@ -1381,7 +1456,7 @@ mod tests {
13811456
let nesting_tls_connector =
13821457
NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config));
13831458

1384-
let (sender, conn) = ProxyClient::setup_connection(
1459+
let (sender, conn, _attestation) = ProxyClient::setup_connection(
13851460
&nesting_tls_connector,
13861461
&format!("localhost:{}", proxy_addr.port()),
13871462
)
@@ -1445,6 +1520,9 @@ mod tests {
14451520
.await
14461521
.unwrap();
14471522

1523+
assert_attestation_type_header(res.headers(), "dcap-tdx");
1524+
assert_mock_measurements_header(res.headers());
1525+
14481526
let res_body = res.text().await.unwrap();
14491527
assert_eq!(res_body, "No measurements");
14501528
}
@@ -1513,8 +1591,11 @@ mod tests {
15131591
.await
15141592
.unwrap();
15151593

1594+
assert_attestation_type_header(res.headers(), "none");
1595+
assert_no_measurements_header(res.headers());
1596+
15161597
let res_body = res.text().await.unwrap();
1517-
assert_eq!(res_body, "No measurements");
1598+
assert_mock_measurements(&res_body);
15181599
}
15191600

15201601
// Server has no attestation, client has mock DCAP but no client auth
@@ -1574,7 +1655,11 @@ mod tests {
15741655
.await
15751656
.unwrap();
15761657

1577-
let _res_body = res.text().await.unwrap();
1658+
assert_attestation_type_header(res.headers(), "none");
1659+
assert_no_measurements_header(res.headers());
1660+
1661+
let res_body = res.text().await.unwrap();
1662+
assert_eq!(res_body, "No measurements");
15781663
}
15791664

15801665
// Server has mock DCAP, client has mock DCAP and client auth
@@ -1641,12 +1726,16 @@ mod tests {
16411726
let res = reqwest::get(format!("http://{}", proxy_client_addr))
16421727
.await
16431728
.unwrap();
1644-
assert_eq!(res.text().await.unwrap(), "No measurements");
1729+
assert_attestation_type_header(res.headers(), "dcap-tdx");
1730+
assert_mock_measurements_header(res.headers());
1731+
assert_mock_measurements(&res.text().await.unwrap());
16451732

16461733
let res = reqwest::get(format!("http://{}", proxy_client_addr))
16471734
.await
16481735
.unwrap();
1649-
assert_eq!(res.text().await.unwrap(), "No measurements");
1736+
assert_attestation_type_header(res.headers(), "dcap-tdx");
1737+
assert_mock_measurements_header(res.headers());
1738+
assert_mock_measurements(&res.text().await.unwrap());
16501739
}
16511740

16521741
// Server has mock DCAP, client no attestation - just get the server certificate
@@ -1874,9 +1963,11 @@ mod tests {
18741963
proxy_client.accept().await.unwrap();
18751964
});
18761965

1877-
let _initial_response = reqwest::get(format!("http://{}", proxy_client_addr))
1966+
let initial_response = reqwest::get(format!("http://{}", proxy_client_addr))
18781967
.await
18791968
.unwrap();
1969+
assert_attestation_type_header(initial_response.headers(), "dcap-tdx");
1970+
assert_mock_measurements_header(initial_response.headers());
18801971

18811972
// Now break the connection
18821973
connection_breaker_tx.send(()).unwrap();
@@ -1886,6 +1977,9 @@ mod tests {
18861977
.await
18871978
.unwrap();
18881979

1980+
assert_attestation_type_header(res.headers(), "dcap-tdx");
1981+
assert_mock_measurements_header(res.headers());
1982+
18891983
let res_body = res.text().await.unwrap();
18901984
assert_eq!(res_body, "No measurements");
18911985
}
@@ -1945,6 +2039,9 @@ mod tests {
19452039
.await
19462040
.unwrap();
19472041

2042+
assert_attestation_type_header(res.headers(), "dcap-tdx");
2043+
assert_mock_measurements_header(res.headers());
2044+
19482045
let res_body = res.text().await.unwrap();
19492046
assert_eq!(res_body, "No measurements");
19502047
}

src/test_helpers.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ use tokio_rustls::rustls::{
1212
};
1313
use tracing_subscriber::{EnvFilter, fmt};
1414

15+
use crate::MEASUREMENT_HEADER;
16+
1517
static INIT: Once = Once::new();
1618

1719
/// Helper to generate a self-signed certificate for testing with a DNS subject name
@@ -127,13 +129,12 @@ pub async fn example_http_service() -> SocketAddr {
127129
addr
128130
}
129131

130-
async fn get_handler(_headers: http::HeaderMap) -> impl IntoResponse {
131-
// headers
132-
// .get(MEASUREMENT_HEADER)
133-
// .and_then(|v| v.to_str().ok())
134-
// .unwrap_or("No measurements")
135-
// .to_string()
136-
"No measurements".to_string()
132+
async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse {
133+
headers
134+
.get(MEASUREMENT_HEADER)
135+
.and_then(|v| v.to_str().ok())
136+
.unwrap_or("No measurements")
137+
.to_string()
137138
}
138139

139140
pub fn init_tracing() {

0 commit comments

Comments
 (0)