@@ -52,6 +52,9 @@ type RequestWithResponseSender = (
5252 oneshot:: Sender < Result < Response < BoxBody < bytes:: Bytes , hyper:: Error > > , hyper:: Error > > ,
5353) ;
5454
55+ type OuterProxySession = ( Arc < TcpListener > , NestingTlsAcceptor ) ;
56+ type InnerProxySession = ( Arc < TcpListener > , TlsAcceptor ) ;
57+
5558/// TLS Credentials
5659pub struct TlsCertAndKey {
5760 /// Der-encoded TLS certificate chain
@@ -207,27 +210,30 @@ pub async fn get_inner_tls_cert_with_config(
207210
208211/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
209212pub struct ProxyServer {
210- outer_listener : Option < Arc < TcpListener > > ,
211- outer_tls_acceptor : Option < NestingTlsAcceptor > ,
212- inner_listener : Arc < TcpListener > ,
213- inner_tls_acceptor : TlsAcceptor ,
213+ outer : Option < OuterProxySession > ,
214+ inner : Option < InnerProxySession > ,
214215 /// The address/hostname of the target service we are proxying to
215216 target : String ,
216217}
217218
218219impl ProxyServer {
219220 /// Start with dual listeners. The outer nested-TLS listener is optional.
220- pub async fn new < O > (
221+ pub async fn new < O , I > (
221222 outer_session : Option < OuterTlsConfig < O > > ,
222- inner_local : impl ToSocketAddrs ,
223+ inner_local : Option < I > ,
223224 target : String ,
224225 attestation_generator : AttestationGenerator ,
225226 attestation_verifier : AttestationVerifier ,
226227 client_auth : bool ,
227228 ) -> Result < Self , ProxyError >
228229 where
229230 O : ToSocketAddrs ,
231+ I : ToSocketAddrs ,
230232 {
233+ if outer_session. is_none ( ) && inner_local. is_none ( ) {
234+ return Err ( ProxyError :: NoListenersConfigured ) ;
235+ }
236+
231237 let certificate_name = outer_session
232238 . as_ref ( )
233239 . map ( OuterTlsConfig :: certificate_name)
@@ -241,24 +247,28 @@ impl ProxyServer {
241247 )
242248 . await ?,
243249 ) ;
244- let inner_listener = Arc :: new ( TcpListener :: bind ( inner_local) . await ?) ;
245- let inner_tls_acceptor = TlsAcceptor :: from ( inner_server_config. clone ( ) ) ;
250+ let inner = match inner_local {
251+ Some ( inner_local) => {
252+ let inner_listener = Arc :: new ( TcpListener :: bind ( inner_local) . await ?) ;
253+ let inner_tls_acceptor = TlsAcceptor :: from ( inner_server_config. clone ( ) ) ;
254+ Some ( ( inner_listener, inner_tls_acceptor) )
255+ }
256+ None => None ,
257+ } ;
246258
247- let ( outer_listener , outer_tls_acceptor ) = match outer_session {
259+ let outer = match outer_session {
248260 Some ( outer_session) => {
249261 let ( outer_listener, outer_tls_acceptor) = outer_session
250262 . into_listener_and_acceptor ( inner_server_config. clone ( ) , client_auth)
251263 . await ?;
252- ( Some ( outer_listener) , Some ( outer_tls_acceptor) )
264+ Some ( ( outer_listener, outer_tls_acceptor) )
253265 }
254- None => ( None , None ) ,
266+ None => None ,
255267 } ;
256268
257269 Ok ( Self {
258- outer_listener,
259- outer_tls_acceptor,
260- inner_listener,
261- inner_tls_acceptor,
270+ outer,
271+ inner,
262272 target,
263273 } )
264274 }
@@ -268,13 +278,14 @@ impl ProxyServer {
268278 /// Returns the handle for the task handling the connection
269279 pub async fn accept ( & self ) -> Result < tokio:: task:: JoinHandle < ( ) > , ProxyError > {
270280 let target = self . target . clone ( ) ;
271- let outer_listener = self . outer_listener . clone ( ) ;
272- let outer_tls_acceptor = self . outer_tls_acceptor . clone ( ) ;
273- let inner_listener = self . inner_listener . clone ( ) ;
274- let inner_tls_acceptor = self . inner_tls_acceptor . clone ( ) ;
275-
276- let join_handle = match ( outer_listener, outer_tls_acceptor) {
277- ( Some ( outer_listener) , Some ( outer_tls_acceptor) ) => {
281+ let outer = self . outer . clone ( ) ;
282+ let inner = self . inner . clone ( ) ;
283+
284+ let join_handle = match ( outer, inner) {
285+ (
286+ Some ( ( outer_listener, outer_tls_acceptor) ) ,
287+ Some ( ( inner_listener, inner_tls_acceptor) ) ,
288+ ) => {
278289 let ( ( inbound, client_addr) , use_outer) = tokio:: select! {
279290 accepted = outer_listener. accept( ) => ( accepted?, true ) ,
280291 accepted = inner_listener. accept( ) => ( accepted?, false ) ,
@@ -312,7 +323,7 @@ impl ProxyServer {
312323 }
313324 } )
314325 }
315- _ => {
326+ ( None , Some ( ( inner_listener , inner_tls_acceptor ) ) ) => {
316327 let ( inbound, client_addr) = inner_listener. accept ( ) . await ?;
317328 tokio:: spawn ( async move {
318329 match inner_tls_acceptor. accept ( inbound) . await {
@@ -329,28 +340,54 @@ impl ProxyServer {
329340 }
330341 } )
331342 }
343+ ( Some ( ( outer_listener, outer_tls_acceptor) ) , None ) => {
344+ let ( inbound, client_addr) = outer_listener. accept ( ) . await ?;
345+ tokio:: spawn ( async move {
346+ match outer_tls_acceptor. accept ( inbound) . await {
347+ Ok ( tls_stream) => {
348+ if let Err ( err) =
349+ Self :: handle_outer_connection ( tls_stream, target, client_addr) . await
350+ {
351+ warn ! ( "Failed to handle outer connection: {err}" ) ;
352+ }
353+ }
354+ Err ( err) => {
355+ warn ! ( "Outer attestation exchange failed: {err}" ) ;
356+ }
357+ }
358+ } )
359+ }
360+ _ => return Err ( ProxyError :: NoListenersConfigured ) ,
332361 } ;
333362
334363 Ok ( join_handle)
335364 }
336365
337366 /// Helper to get the socket address of the underlying TCP listener
338367 pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
339- match & self . outer_listener {
340- Some ( listener) => listener. local_addr ( ) ,
341- None => self . inner_listener . local_addr ( ) ,
368+ match & self . outer {
369+ Some ( ( listener, _) ) => listener. local_addr ( ) ,
370+ None => self
371+ . inner
372+ . as_ref ( )
373+ . map ( |( listener, _) | listener)
374+ . ok_or_else ( || std:: io:: Error :: other ( "no listeners configured" ) ) ?
375+ . local_addr ( ) ,
342376 }
343377 }
344378
345379 pub fn outer_local_addr ( & self ) -> std:: io:: Result < Option < SocketAddr > > {
346- self . outer_listener
380+ self . outer
347381 . as_ref ( )
348- . map ( |listener| listener. local_addr ( ) )
382+ . map ( |( listener, _ ) | listener. local_addr ( ) )
349383 . transpose ( )
350384 }
351385
352- pub fn inner_local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
353- self . inner_listener . local_addr ( )
386+ pub fn inner_local_addr ( & self ) -> std:: io:: Result < Option < SocketAddr > > {
387+ self . inner
388+ . as_ref ( )
389+ . map ( |( listener, _) | listener. local_addr ( ) )
390+ . transpose ( )
354391 }
355392
356393 async fn handle_outer_connection (
@@ -909,6 +946,8 @@ pub enum ProxyError {
909946 MpscSend ,
910947 #[ error( "Client auth must be configured on both the inner and outer TLS sessions" ) ]
911948 ClientAuthMisconfigured ,
949+ #[ error( "At least one server listener must be configured" ) ]
950+ NoListenersConfigured ,
912951}
913952
914953impl From < mpsc:: error:: SendError < RequestWithResponseSender > > for ProxyError {
@@ -1039,6 +1078,21 @@ mod tests {
10391078 assert_eq ! ( protocols, vec![ ALPN_HTTP11 . to_vec( ) , ALPN_H2 . to_vec( ) ] ) ;
10401079 }
10411080
1081+ #[ tokio:: test( flavor = "multi_thread" ) ]
1082+ async fn proxy_server_requires_at_least_one_listener ( ) {
1083+ let result = ProxyServer :: new (
1084+ None :: < OuterTlsConfig < & str > > ,
1085+ None :: < & str > ,
1086+ "127.0.0.1:1" . to_string ( ) ,
1087+ AttestationGenerator :: with_no_attestation ( ) ,
1088+ AttestationVerifier :: expect_none ( ) ,
1089+ false ,
1090+ )
1091+ . await ;
1092+
1093+ assert ! ( matches!( result, Err ( ProxyError :: NoListenersConfigured ) ) ) ;
1094+ }
1095+
10421096 #[ tokio:: test( flavor = "multi_thread" ) ]
10431097 async fn dual_listener_server_reports_expected_addresses ( ) {
10441098 let target_addr = example_http_service ( ) . await ;
@@ -1054,7 +1108,7 @@ mod tests {
10541108 listen_addr : "127.0.0.1:0" ,
10551109 tls : OuterTlsMode :: CertAndKey ( tls_cert_and_key) ,
10561110 } ) ,
1057- "127.0.0.1:0" ,
1111+ Some ( "127.0.0.1:0" ) ,
10581112 target_addr. to_string ( ) ,
10591113 AttestationGenerator :: with_no_attestation ( ) ,
10601114 AttestationVerifier :: expect_none ( ) ,
@@ -1064,13 +1118,13 @@ mod tests {
10641118 . unwrap ( ) ;
10651119
10661120 let outer_addr = dual_listener_server. outer_local_addr ( ) . unwrap ( ) . unwrap ( ) ;
1067- let inner_addr = dual_listener_server. inner_local_addr ( ) . unwrap ( ) ;
1121+ let inner_addr = dual_listener_server. inner_local_addr ( ) . unwrap ( ) . unwrap ( ) ;
10681122 assert_eq ! ( dual_listener_server. local_addr( ) . unwrap( ) , outer_addr) ;
10691123 assert_ne ! ( outer_addr, inner_addr) ;
10701124
10711125 let inner_only_server = ProxyServer :: new (
10721126 None :: < OuterTlsConfig < & str > > ,
1073- "127.0.0.1:0" ,
1127+ Some ( "127.0.0.1:0" ) ,
10741128 target_addr. to_string ( ) ,
10751129 AttestationGenerator :: with_no_attestation ( ) ,
10761130 AttestationVerifier :: expect_none ( ) ,
@@ -1079,7 +1133,7 @@ mod tests {
10791133 . await
10801134 . unwrap ( ) ;
10811135
1082- let inner_only_addr = inner_only_server. inner_local_addr ( ) . unwrap ( ) ;
1136+ let inner_only_addr = inner_only_server. inner_local_addr ( ) . unwrap ( ) . unwrap ( ) ;
10831137 assert ! ( inner_only_server. outer_local_addr( ) . unwrap( ) . is_none( ) ) ;
10841138 assert_eq ! ( inner_only_server. local_addr( ) . unwrap( ) , inner_only_addr) ;
10851139 }
@@ -1091,7 +1145,7 @@ mod tests {
10911145
10921146 let proxy_server = ProxyServer :: new (
10931147 None :: < OuterTlsConfig < & str > > ,
1094- "127.0.0.1:0" ,
1148+ Some ( "127.0.0.1:0" ) ,
10951149 target_addr. to_string ( ) ,
10961150 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
10971151 AttestationVerifier :: expect_none ( ) ,
@@ -1100,7 +1154,7 @@ mod tests {
11001154 . await
11011155 . unwrap ( ) ;
11021156
1103- let inner_addr = proxy_server. inner_local_addr ( ) . unwrap ( ) ;
1157+ let inner_addr = proxy_server. inner_local_addr ( ) . unwrap ( ) . unwrap ( ) ;
11041158
11051159 tokio:: spawn ( async move {
11061160 proxy_server. accept ( ) . await . unwrap ( ) ;
@@ -1147,7 +1201,7 @@ mod tests {
11471201 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
11481202 } ,
11491203 } ) ,
1150- "127.0.0.1:0" ,
1204+ Some ( "127.0.0.1:0" ) ,
11511205 target_addr. to_string ( ) ,
11521206 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
11531207 AttestationVerifier :: expect_none ( ) ,
@@ -1254,7 +1308,7 @@ mod tests {
12541308 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
12551309 } ,
12561310 } ) ,
1257- "127.0.0.1:0" ,
1311+ Some ( "127.0.0.1:0" ) ,
12581312 target_addr. to_string ( ) ,
12591313 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
12601314 AttestationVerifier :: expect_none ( ) ,
@@ -1322,7 +1376,7 @@ mod tests {
13221376 certificate_name : certificate_identity_from_chain ( & server_cert_chain) . unwrap ( ) ,
13231377 } ,
13241378 } ) ,
1325- "127.0.0.1:0" ,
1379+ Some ( "127.0.0.1:0" ) ,
13261380 target_addr. to_string ( ) ,
13271381 AttestationGenerator :: with_no_attestation ( ) ,
13281382 AttestationVerifier :: mock ( ) ,
@@ -1380,7 +1434,7 @@ mod tests {
13801434 certificate_name : certificate_identity_from_chain ( & server_cert_chain) . unwrap ( ) ,
13811435 } ,
13821436 } ) ,
1383- "127.0.0.1:0" ,
1437+ Some ( "127.0.0.1:0" ) ,
13841438 target_addr. to_string ( ) ,
13851439 AttestationGenerator :: with_no_attestation ( ) ,
13861440 AttestationVerifier :: mock ( ) ,
@@ -1450,7 +1504,7 @@ mod tests {
14501504 certificate_name : certificate_identity_from_chain ( & server_cert_chain) . unwrap ( ) ,
14511505 } ,
14521506 } ) ,
1453- "127.0.0.1:0" ,
1507+ Some ( "127.0.0.1:0" ) ,
14541508 target_addr. to_string ( ) ,
14551509 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
14561510 AttestationVerifier :: mock ( ) ,
@@ -1510,7 +1564,7 @@ mod tests {
15101564 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
15111565 } ,
15121566 } ) ,
1513- "127.0.0.1:0" ,
1567+ Some ( "127.0.0.1:0" ) ,
15141568 target_addr. to_string ( ) ,
15151569 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
15161570 AttestationVerifier :: expect_none ( ) ,
@@ -1558,7 +1612,7 @@ mod tests {
15581612 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
15591613 } ,
15601614 } ) ,
1561- "127.0.0.1:0" ,
1615+ Some ( "127.0.0.1:0" ) ,
15621616 target_addr. to_string ( ) ,
15631617 AttestationGenerator :: with_no_attestation ( ) ,
15641618 AttestationVerifier :: expect_none ( ) ,
@@ -1604,7 +1658,7 @@ mod tests {
16041658 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
16051659 } ,
16061660 } ) ,
1607- "127.0.0.1:0" ,
1661+ Some ( "127.0.0.1:0" ) ,
16081662 target_addr. to_string ( ) ,
16091663 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
16101664 AttestationVerifier :: expect_none ( ) ,
@@ -1675,7 +1729,7 @@ mod tests {
16751729 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
16761730 } ,
16771731 } ) ,
1678- "127.0.0.1:0" ,
1732+ Some ( "127.0.0.1:0" ) ,
16791733 target_addr. to_string ( ) ,
16801734 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
16811735 AttestationVerifier :: expect_none ( ) ,
@@ -1754,7 +1808,7 @@ mod tests {
17541808 certificate_name : certificate_identity_from_chain ( & cert_chain) . unwrap ( ) ,
17551809 } ,
17561810 } ) ,
1757- "127.0.0.1:0" ,
1811+ Some ( "127.0.0.1:0" ) ,
17581812 target_addr. to_string ( ) ,
17591813 AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
17601814 AttestationVerifier :: expect_none ( ) ,
0 commit comments