@@ -734,7 +734,7 @@ impl ProxyClient {
734734 let mut first = true ;
735735 let mut ready_tx = Some ( ready_tx) ;
736736 ' reconnect: loop {
737- let ( mut sender, conn) =
737+ let ( mut sender, conn, attestation ) =
738738 // Connect to the proxy server and provide / verify attestation
739739 match Self :: setup_connection_with_backoff ( & target, & nesting_tls_connector, first)
740740 . await
@@ -764,6 +764,9 @@ impl ProxyClient {
764764 let ( conn_done_tx, mut conn_done_rx) =
765765 tokio:: sync:: watch:: channel :: < Option < hyper:: Error > > ( None ) ;
766766
767+ let mut remote_attestation_type = attestation. attestation_type ;
768+ let mut measurements = attestation. get_measurements ( ) . ok ( ) . flatten ( ) ;
769+
767770 tokio:: spawn ( async move {
768771 let res = conn. await ;
769772 let _ = conn_done_tx. send ( res. err ( ) ) ;
@@ -787,8 +790,10 @@ impl ProxyClient {
787790 )
788791 . await
789792 {
790- Ok ( ( new_sender, new_conn) ) => {
793+ Ok ( ( new_sender, new_conn, new_attestation ) ) => {
791794 sender = new_sender;
795+ remote_attestation_type = new_attestation. attestation_type;
796+ measurements = new_attestation. get_measurements( ) . ok( ) . flatten( ) ;
792797
793798 let ( new_conn_done_tx, new_conn_done_rx) =
794799 tokio:: sync:: watch:: channel:: <Option <hyper:: Error >>( None ) ;
@@ -815,8 +820,26 @@ impl ProxyClient {
815820 } ;
816821
817822 match send_result {
818- Ok ( resp) => {
823+ Ok ( mut resp) => {
819824 debug!( "[proxy-client] Read response from proxy-server: {resp:?}" ) ;
825+ let headers = resp. headers_mut( ) ;
826+ if let Some ( measurements) = measurements. clone( ) {
827+ match measurements. to_header_format( ) {
828+ Ok ( header_value) => {
829+ headers. insert( MEASUREMENT_HEADER , header_value) ;
830+ }
831+ Err ( e) => {
832+ error!( "Failed to encode measurement values: {e}" ) ;
833+ }
834+ }
835+ }
836+
837+ update_header(
838+ headers,
839+ ATTESTATION_TYPE_HEADER ,
840+ remote_attestation_type. as_str( ) ,
841+ ) ;
842+
820843 break Ok ( resp. map( |b| b. boxed( ) ) ) ;
821844 }
822845 Err ( e) => {
@@ -828,8 +851,10 @@ impl ProxyClient {
828851 )
829852 . await
830853 {
831- Ok ( ( new_sender, new_conn) ) => {
854+ Ok ( ( new_sender, new_conn, new_attestation ) ) => {
832855 sender = new_sender;
856+ remote_attestation_type = new_attestation. attestation_type;
857+ measurements = new_attestation. get_measurements( ) . ok( ) . flatten( ) ;
833858
834859 let ( new_conn_done_tx, new_conn_done_rx) =
835860 tokio:: sync:: watch:: channel:: <Option <hyper:: Error >>( None ) ;
@@ -946,7 +971,7 @@ impl ProxyClient {
946971 target : & str ,
947972 nesting_tls_connector : & NestingTlsConnector ,
948973 should_bail : bool ,
949- ) -> Result < ( HttpSender , HttpConnection ) , ProxyError > {
974+ ) -> Result < ( HttpSender , HttpConnection , AttestationExchangeMessage ) , ProxyError > {
950975 let mut delay = Duration :: from_secs ( 1 ) ;
951976 let max_delay = Duration :: from_secs ( SERVER_RECONNECT_MAX_BACKOFF_SECS ) ;
952977
@@ -975,7 +1000,7 @@ impl ProxyClient {
9751000 async fn setup_connection (
9761001 nesting_tls_connector : & NestingTlsConnector ,
9771002 target : & str ,
978- ) -> Result < ( HttpSender , HttpConnection ) , ProxyError > {
1003+ ) -> Result < ( HttpSender , HttpConnection , AttestationExchangeMessage ) , ProxyError > {
9791004 let outbound_stream = tokio:: net:: TcpStream :: connect ( target) . await ?;
9801005
9811006 let domain = server_name_from_host ( target) ?;
@@ -985,6 +1010,18 @@ impl ProxyClient {
9851010
9861011 debug ! ( "[proxy-client] Connected to proxy server" ) ;
9871012
1013+ let attestation = {
1014+ let ( _io, server_connection) = tls_stream. get_ref ( ) ;
1015+
1016+ let remote_cert_chain = server_connection
1017+ . peer_certificates ( )
1018+ . ok_or ( ProxyError :: NoCertificate ) ?;
1019+
1020+ AttestedCertificateVerifier :: extract_custom_attestation_from_cert (
1021+ remote_cert_chain. first ( ) . ok_or ( ProxyError :: NoCertificate ) ?,
1022+ ) ?
1023+ } ;
1024+
9881025 // The attestation exchange is now complete - setup an HTTP client
9891026 let http_version = HttpVersion :: from_negotiated_protocol_client ( & tls_stream) ;
9901027
@@ -1008,7 +1045,7 @@ impl ProxyClient {
10081045 }
10091046 } ;
10101047
1011- Ok ( ( sender, conn) )
1048+ Ok ( ( sender, conn, attestation ) )
10121049 }
10131050
10141051 // Handle a request from the source client to the proxy server
@@ -1207,6 +1244,7 @@ where
12071244#[ cfg( test) ]
12081245mod tests {
12091246 use attestation:: { AttestationType , measurements:: MeasurementPolicy } ;
1247+ use std:: collections:: HashMap ;
12101248 use tokio_rustls:: TlsConnector ;
12111249
12121250 use super :: * ;
@@ -1215,6 +1253,43 @@ mod tests {
12151253 generate_tls_config_with_client_auth, init_tracing,
12161254 } ;
12171255
1256+ fn expected_mock_measurements ( ) -> HashMap < String , String > {
1257+ let zero_measurement = "0" . repeat ( 96 ) ;
1258+ HashMap :: from ( [
1259+ ( "0" . to_string ( ) , zero_measurement. clone ( ) ) ,
1260+ ( "1" . to_string ( ) , zero_measurement. clone ( ) ) ,
1261+ ( "2" . to_string ( ) , zero_measurement. clone ( ) ) ,
1262+ ( "3" . to_string ( ) , zero_measurement. clone ( ) ) ,
1263+ ( "4" . to_string ( ) , zero_measurement) ,
1264+ ] )
1265+ }
1266+
1267+ fn assert_mock_measurements ( body : & str ) {
1268+ let parsed: HashMap < String , String > = serde_json:: from_str ( body) . unwrap ( ) ;
1269+ assert_eq ! ( parsed, expected_mock_measurements( ) ) ;
1270+ }
1271+
1272+ fn assert_mock_measurements_header ( headers : & http:: HeaderMap ) {
1273+ let body = headers
1274+ . get ( MEASUREMENT_HEADER )
1275+ . and_then ( |v| v. to_str ( ) . ok ( ) )
1276+ . unwrap ( ) ;
1277+ assert_mock_measurements ( body) ;
1278+ }
1279+
1280+ fn assert_attestation_type_header ( headers : & http:: HeaderMap , expected : & str ) {
1281+ assert_eq ! (
1282+ headers
1283+ . get( ATTESTATION_TYPE_HEADER )
1284+ . and_then( |v| v. to_str( ) . ok( ) ) ,
1285+ Some ( expected)
1286+ ) ;
1287+ }
1288+
1289+ fn assert_no_measurements_header ( headers : & http:: HeaderMap ) {
1290+ assert ! ( headers. get( MEASUREMENT_HEADER ) . is_none( ) ) ;
1291+ }
1292+
12181293 #[ test]
12191294 fn proxy_alpn_protocols_prefer_http2 ( ) {
12201295 let mut protocols = Vec :: new ( ) ;
@@ -1381,7 +1456,7 @@ mod tests {
13811456 let nesting_tls_connector =
13821457 NestingTlsConnector :: new ( Arc :: new ( outer_client_config) , Arc :: new ( inner_client_config) ) ;
13831458
1384- let ( sender, conn) = ProxyClient :: setup_connection (
1459+ let ( sender, conn, _attestation ) = ProxyClient :: setup_connection (
13851460 & nesting_tls_connector,
13861461 & format ! ( "localhost:{}" , proxy_addr. port( ) ) ,
13871462 )
@@ -1445,6 +1520,9 @@ mod tests {
14451520 . await
14461521 . unwrap ( ) ;
14471522
1523+ assert_attestation_type_header ( res. headers ( ) , "dcap-tdx" ) ;
1524+ assert_mock_measurements_header ( res. headers ( ) ) ;
1525+
14481526 let res_body = res. text ( ) . await . unwrap ( ) ;
14491527 assert_eq ! ( res_body, "No measurements" ) ;
14501528 }
@@ -1513,8 +1591,11 @@ mod tests {
15131591 . await
15141592 . unwrap ( ) ;
15151593
1594+ assert_attestation_type_header ( res. headers ( ) , "none" ) ;
1595+ assert_no_measurements_header ( res. headers ( ) ) ;
1596+
15161597 let res_body = res. text ( ) . await . unwrap ( ) ;
1517- assert_eq ! ( res_body, "No measurements" ) ;
1598+ assert_mock_measurements ( & res_body) ;
15181599 }
15191600
15201601 // Server has no attestation, client has mock DCAP but no client auth
@@ -1574,7 +1655,11 @@ mod tests {
15741655 . await
15751656 . unwrap ( ) ;
15761657
1577- let _res_body = res. text ( ) . await . unwrap ( ) ;
1658+ assert_attestation_type_header ( res. headers ( ) , "none" ) ;
1659+ assert_no_measurements_header ( res. headers ( ) ) ;
1660+
1661+ let res_body = res. text ( ) . await . unwrap ( ) ;
1662+ assert_eq ! ( res_body, "No measurements" ) ;
15781663 }
15791664
15801665 // Server has mock DCAP, client has mock DCAP and client auth
@@ -1641,12 +1726,16 @@ mod tests {
16411726 let res = reqwest:: get ( format ! ( "http://{}" , proxy_client_addr) )
16421727 . await
16431728 . unwrap ( ) ;
1644- assert_eq ! ( res. text( ) . await . unwrap( ) , "No measurements" ) ;
1729+ assert_attestation_type_header ( res. headers ( ) , "dcap-tdx" ) ;
1730+ assert_mock_measurements_header ( res. headers ( ) ) ;
1731+ assert_mock_measurements ( & res. text ( ) . await . unwrap ( ) ) ;
16451732
16461733 let res = reqwest:: get ( format ! ( "http://{}" , proxy_client_addr) )
16471734 . await
16481735 . unwrap ( ) ;
1649- assert_eq ! ( res. text( ) . await . unwrap( ) , "No measurements" ) ;
1736+ assert_attestation_type_header ( res. headers ( ) , "dcap-tdx" ) ;
1737+ assert_mock_measurements_header ( res. headers ( ) ) ;
1738+ assert_mock_measurements ( & res. text ( ) . await . unwrap ( ) ) ;
16501739 }
16511740
16521741 // Server has mock DCAP, client no attestation - just get the server certificate
@@ -1874,9 +1963,11 @@ mod tests {
18741963 proxy_client. accept ( ) . await . unwrap ( ) ;
18751964 } ) ;
18761965
1877- let _initial_response = reqwest:: get ( format ! ( "http://{}" , proxy_client_addr) )
1966+ let initial_response = reqwest:: get ( format ! ( "http://{}" , proxy_client_addr) )
18781967 . await
18791968 . unwrap ( ) ;
1969+ assert_attestation_type_header ( initial_response. headers ( ) , "dcap-tdx" ) ;
1970+ assert_mock_measurements_header ( initial_response. headers ( ) ) ;
18801971
18811972 // Now break the connection
18821973 connection_breaker_tx. send ( ( ) ) . unwrap ( ) ;
@@ -1886,6 +1977,9 @@ mod tests {
18861977 . await
18871978 . unwrap ( ) ;
18881979
1980+ assert_attestation_type_header ( res. headers ( ) , "dcap-tdx" ) ;
1981+ assert_mock_measurements_header ( res. headers ( ) ) ;
1982+
18891983 let res_body = res. text ( ) . await . unwrap ( ) ;
18901984 assert_eq ! ( res_body, "No measurements" ) ;
18911985 }
@@ -1945,6 +2039,9 @@ mod tests {
19452039 . await
19462040 . unwrap ( ) ;
19472041
2042+ assert_attestation_type_header ( res. headers ( ) , "dcap-tdx" ) ;
2043+ assert_mock_measurements_header ( res. headers ( ) ) ;
2044+
19482045 let res_body = res. text ( ) . await . unwrap ( ) ;
19492046 assert_eq ! ( res_body, "No measurements" ) ;
19502047 }
0 commit comments