Skip to content

Commit a42c81d

Browse files
committed
Make inner and outer session optional and dont use default ports
1 parent b3ebc9c commit a42c81d

4 files changed

Lines changed: 160 additions & 66 deletions

File tree

src/attested_get.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ mod tests {
8585
certificate_name: "localhost".to_string(),
8686
},
8787
}),
88-
"127.0.0.1:0",
88+
Some("127.0.0.1:0"),
8989
target_addr.to_string(),
9090
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
9191
AttestationVerifier::expect_none(),

src/file_server.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@ use tower_http::services::ServeDir;
1111
pub async fn attested_file_server(
1212
path_to_serve: PathBuf,
1313
outer_cert_and_key: Option<TlsCertAndKey>,
14-
outer_listen_addr: impl ToSocketAddrs,
15-
inner_listen_addr: impl ToSocketAddrs,
14+
outer_listen_addr: Option<impl ToSocketAddrs>,
15+
inner_listen_addr: Option<impl ToSocketAddrs>,
1616
attestation_generator: AttestationGenerator,
1717
attestation_verifier: AttestationVerifier,
1818
client_auth: bool,
1919
) -> Result<(), ProxyError> {
2020
let target_addr = static_file_server(path_to_serve).await?;
2121

2222
let server = ProxyServer::new(
23-
outer_cert_and_key.map(|cert_and_key| OuterTlsConfig {
24-
listen_addr: outer_listen_addr,
25-
tls: OuterTlsMode::CertAndKey(cert_and_key),
26-
}),
23+
outer_cert_and_key
24+
.zip(outer_listen_addr)
25+
.map(|(cert_and_key, listen_addr)| OuterTlsConfig {
26+
listen_addr,
27+
tls: OuterTlsMode::CertAndKey(cert_and_key),
28+
}),
2729
inner_listen_addr,
2830
target_addr.to_string(),
2931
attestation_generator,
@@ -113,7 +115,7 @@ mod tests {
113115
certificate_name: "localhost".to_string(),
114116
},
115117
}),
116-
"127.0.0.1:0",
118+
Some("127.0.0.1:0"),
117119
target_addr.to_string(),
118120
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
119121
AttestationVerifier::expect_none(),

src/lib.rs

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ type RequestWithResponseSender = (
5252
oneshot::Sender<Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error>>,
5353
);
5454

55+
type OuterProxySession = (Arc<TcpListener>, NestingTlsAcceptor);
56+
type InnerProxySession = (Arc<TcpListener>, TlsAcceptor);
57+
5558
/// TLS Credentials
5659
pub struct TlsCertAndKey {
5760
/// Der-encoded TLS certificate chain
@@ -207,27 +210,30 @@ pub async fn get_inner_tls_cert_with_config(
207210

208211
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
209212
pub struct ProxyServer {
210-
outer_listener: Option<Arc<TcpListener>>,
211-
outer_tls_acceptor: Option<NestingTlsAcceptor>,
212-
inner_listener: Arc<TcpListener>,
213-
inner_tls_acceptor: TlsAcceptor,
213+
outer: Option<OuterProxySession>,
214+
inner: Option<InnerProxySession>,
214215
/// The address/hostname of the target service we are proxying to
215216
target: String,
216217
}
217218

218219
impl ProxyServer {
219220
/// Start with dual listeners. The outer nested-TLS listener is optional.
220-
pub async fn new<O>(
221+
pub async fn new<O, I>(
221222
outer_session: Option<OuterTlsConfig<O>>,
222-
inner_local: impl ToSocketAddrs,
223+
inner_local: Option<I>,
223224
target: String,
224225
attestation_generator: AttestationGenerator,
225226
attestation_verifier: AttestationVerifier,
226227
client_auth: bool,
227228
) -> Result<Self, ProxyError>
228229
where
229230
O: ToSocketAddrs,
231+
I: ToSocketAddrs,
230232
{
233+
if outer_session.is_none() && inner_local.is_none() {
234+
return Err(ProxyError::NoListenersConfigured);
235+
}
236+
231237
let certificate_name = outer_session
232238
.as_ref()
233239
.map(OuterTlsConfig::certificate_name)
@@ -241,24 +247,28 @@ impl ProxyServer {
241247
)
242248
.await?,
243249
);
244-
let inner_listener = Arc::new(TcpListener::bind(inner_local).await?);
245-
let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone());
250+
let inner = match inner_local {
251+
Some(inner_local) => {
252+
let inner_listener = Arc::new(TcpListener::bind(inner_local).await?);
253+
let inner_tls_acceptor = TlsAcceptor::from(inner_server_config.clone());
254+
Some((inner_listener, inner_tls_acceptor))
255+
}
256+
None => None,
257+
};
246258

247-
let (outer_listener, outer_tls_acceptor) = match outer_session {
259+
let outer = match outer_session {
248260
Some(outer_session) => {
249261
let (outer_listener, outer_tls_acceptor) = outer_session
250262
.into_listener_and_acceptor(inner_server_config.clone(), client_auth)
251263
.await?;
252-
(Some(outer_listener), Some(outer_tls_acceptor))
264+
Some((outer_listener, outer_tls_acceptor))
253265
}
254-
None => (None, None),
266+
None => None,
255267
};
256268

257269
Ok(Self {
258-
outer_listener,
259-
outer_tls_acceptor,
260-
inner_listener,
261-
inner_tls_acceptor,
270+
outer,
271+
inner,
262272
target,
263273
})
264274
}
@@ -268,13 +278,14 @@ impl ProxyServer {
268278
/// Returns the handle for the task handling the connection
269279
pub async fn accept(&self) -> Result<tokio::task::JoinHandle<()>, ProxyError> {
270280
let target = self.target.clone();
271-
let outer_listener = self.outer_listener.clone();
272-
let outer_tls_acceptor = self.outer_tls_acceptor.clone();
273-
let inner_listener = self.inner_listener.clone();
274-
let inner_tls_acceptor = self.inner_tls_acceptor.clone();
275-
276-
let join_handle = match (outer_listener, outer_tls_acceptor) {
277-
(Some(outer_listener), Some(outer_tls_acceptor)) => {
281+
let outer = self.outer.clone();
282+
let inner = self.inner.clone();
283+
284+
let join_handle = match (outer, inner) {
285+
(
286+
Some((outer_listener, outer_tls_acceptor)),
287+
Some((inner_listener, inner_tls_acceptor)),
288+
) => {
278289
let ((inbound, client_addr), use_outer) = tokio::select! {
279290
accepted = outer_listener.accept() => (accepted?, true),
280291
accepted = inner_listener.accept() => (accepted?, false),
@@ -312,7 +323,7 @@ impl ProxyServer {
312323
}
313324
})
314325
}
315-
_ => {
326+
(None, Some((inner_listener, inner_tls_acceptor))) => {
316327
let (inbound, client_addr) = inner_listener.accept().await?;
317328
tokio::spawn(async move {
318329
match inner_tls_acceptor.accept(inbound).await {
@@ -329,28 +340,54 @@ impl ProxyServer {
329340
}
330341
})
331342
}
343+
(Some((outer_listener, outer_tls_acceptor)), None) => {
344+
let (inbound, client_addr) = outer_listener.accept().await?;
345+
tokio::spawn(async move {
346+
match outer_tls_acceptor.accept(inbound).await {
347+
Ok(tls_stream) => {
348+
if let Err(err) =
349+
Self::handle_outer_connection(tls_stream, target, client_addr).await
350+
{
351+
warn!("Failed to handle outer connection: {err}");
352+
}
353+
}
354+
Err(err) => {
355+
warn!("Outer attestation exchange failed: {err}");
356+
}
357+
}
358+
})
359+
}
360+
_ => return Err(ProxyError::NoListenersConfigured),
332361
};
333362

334363
Ok(join_handle)
335364
}
336365

337366
/// Helper to get the socket address of the underlying TCP listener
338367
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
339-
match &self.outer_listener {
340-
Some(listener) => listener.local_addr(),
341-
None => self.inner_listener.local_addr(),
368+
match &self.outer {
369+
Some((listener, _)) => listener.local_addr(),
370+
None => self
371+
.inner
372+
.as_ref()
373+
.map(|(listener, _)| listener)
374+
.ok_or_else(|| std::io::Error::other("no listeners configured"))?
375+
.local_addr(),
342376
}
343377
}
344378

345379
pub fn outer_local_addr(&self) -> std::io::Result<Option<SocketAddr>> {
346-
self.outer_listener
380+
self.outer
347381
.as_ref()
348-
.map(|listener| listener.local_addr())
382+
.map(|(listener, _)| listener.local_addr())
349383
.transpose()
350384
}
351385

352-
pub fn inner_local_addr(&self) -> std::io::Result<SocketAddr> {
353-
self.inner_listener.local_addr()
386+
pub fn inner_local_addr(&self) -> std::io::Result<Option<SocketAddr>> {
387+
self.inner
388+
.as_ref()
389+
.map(|(listener, _)| listener.local_addr())
390+
.transpose()
354391
}
355392

356393
async fn handle_outer_connection(
@@ -909,6 +946,8 @@ pub enum ProxyError {
909946
MpscSend,
910947
#[error("Client auth must be configured on both the inner and outer TLS sessions")]
911948
ClientAuthMisconfigured,
949+
#[error("At least one server listener must be configured")]
950+
NoListenersConfigured,
912951
}
913952

914953
impl From<mpsc::error::SendError<RequestWithResponseSender>> for ProxyError {
@@ -1039,6 +1078,21 @@ mod tests {
10391078
assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]);
10401079
}
10411080

1081+
#[tokio::test(flavor = "multi_thread")]
1082+
async fn proxy_server_requires_at_least_one_listener() {
1083+
let result = ProxyServer::new(
1084+
None::<OuterTlsConfig<&str>>,
1085+
None::<&str>,
1086+
"127.0.0.1:1".to_string(),
1087+
AttestationGenerator::with_no_attestation(),
1088+
AttestationVerifier::expect_none(),
1089+
false,
1090+
)
1091+
.await;
1092+
1093+
assert!(matches!(result, Err(ProxyError::NoListenersConfigured)));
1094+
}
1095+
10421096
#[tokio::test(flavor = "multi_thread")]
10431097
async fn dual_listener_server_reports_expected_addresses() {
10441098
let target_addr = example_http_service().await;
@@ -1054,7 +1108,7 @@ mod tests {
10541108
listen_addr: "127.0.0.1:0",
10551109
tls: OuterTlsMode::CertAndKey(tls_cert_and_key),
10561110
}),
1057-
"127.0.0.1:0",
1111+
Some("127.0.0.1:0"),
10581112
target_addr.to_string(),
10591113
AttestationGenerator::with_no_attestation(),
10601114
AttestationVerifier::expect_none(),
@@ -1064,13 +1118,13 @@ mod tests {
10641118
.unwrap();
10651119

10661120
let outer_addr = dual_listener_server.outer_local_addr().unwrap().unwrap();
1067-
let inner_addr = dual_listener_server.inner_local_addr().unwrap();
1121+
let inner_addr = dual_listener_server.inner_local_addr().unwrap().unwrap();
10681122
assert_eq!(dual_listener_server.local_addr().unwrap(), outer_addr);
10691123
assert_ne!(outer_addr, inner_addr);
10701124

10711125
let inner_only_server = ProxyServer::new(
10721126
None::<OuterTlsConfig<&str>>,
1073-
"127.0.0.1:0",
1127+
Some("127.0.0.1:0"),
10741128
target_addr.to_string(),
10751129
AttestationGenerator::with_no_attestation(),
10761130
AttestationVerifier::expect_none(),
@@ -1079,7 +1133,7 @@ mod tests {
10791133
.await
10801134
.unwrap();
10811135

1082-
let inner_only_addr = inner_only_server.inner_local_addr().unwrap();
1136+
let inner_only_addr = inner_only_server.inner_local_addr().unwrap().unwrap();
10831137
assert!(inner_only_server.outer_local_addr().unwrap().is_none());
10841138
assert_eq!(inner_only_server.local_addr().unwrap(), inner_only_addr);
10851139
}
@@ -1091,7 +1145,7 @@ mod tests {
10911145

10921146
let proxy_server = ProxyServer::new(
10931147
None::<OuterTlsConfig<&str>>,
1094-
"127.0.0.1:0",
1148+
Some("127.0.0.1:0"),
10951149
target_addr.to_string(),
10961150
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
10971151
AttestationVerifier::expect_none(),
@@ -1100,7 +1154,7 @@ mod tests {
11001154
.await
11011155
.unwrap();
11021156

1103-
let inner_addr = proxy_server.inner_local_addr().unwrap();
1157+
let inner_addr = proxy_server.inner_local_addr().unwrap().unwrap();
11041158

11051159
tokio::spawn(async move {
11061160
proxy_server.accept().await.unwrap();
@@ -1147,7 +1201,7 @@ mod tests {
11471201
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
11481202
},
11491203
}),
1150-
"127.0.0.1:0",
1204+
Some("127.0.0.1:0"),
11511205
target_addr.to_string(),
11521206
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
11531207
AttestationVerifier::expect_none(),
@@ -1254,7 +1308,7 @@ mod tests {
12541308
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
12551309
},
12561310
}),
1257-
"127.0.0.1:0",
1311+
Some("127.0.0.1:0"),
12581312
target_addr.to_string(),
12591313
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
12601314
AttestationVerifier::expect_none(),
@@ -1322,7 +1376,7 @@ mod tests {
13221376
certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(),
13231377
},
13241378
}),
1325-
"127.0.0.1:0",
1379+
Some("127.0.0.1:0"),
13261380
target_addr.to_string(),
13271381
AttestationGenerator::with_no_attestation(),
13281382
AttestationVerifier::mock(),
@@ -1380,7 +1434,7 @@ mod tests {
13801434
certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(),
13811435
},
13821436
}),
1383-
"127.0.0.1:0",
1437+
Some("127.0.0.1:0"),
13841438
target_addr.to_string(),
13851439
AttestationGenerator::with_no_attestation(),
13861440
AttestationVerifier::mock(),
@@ -1450,7 +1504,7 @@ mod tests {
14501504
certificate_name: certificate_identity_from_chain(&server_cert_chain).unwrap(),
14511505
},
14521506
}),
1453-
"127.0.0.1:0",
1507+
Some("127.0.0.1:0"),
14541508
target_addr.to_string(),
14551509
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
14561510
AttestationVerifier::mock(),
@@ -1510,7 +1564,7 @@ mod tests {
15101564
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
15111565
},
15121566
}),
1513-
"127.0.0.1:0",
1567+
Some("127.0.0.1:0"),
15141568
target_addr.to_string(),
15151569
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
15161570
AttestationVerifier::expect_none(),
@@ -1558,7 +1612,7 @@ mod tests {
15581612
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
15591613
},
15601614
}),
1561-
"127.0.0.1:0",
1615+
Some("127.0.0.1:0"),
15621616
target_addr.to_string(),
15631617
AttestationGenerator::with_no_attestation(),
15641618
AttestationVerifier::expect_none(),
@@ -1604,7 +1658,7 @@ mod tests {
16041658
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
16051659
},
16061660
}),
1607-
"127.0.0.1:0",
1661+
Some("127.0.0.1:0"),
16081662
target_addr.to_string(),
16091663
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
16101664
AttestationVerifier::expect_none(),
@@ -1675,7 +1729,7 @@ mod tests {
16751729
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
16761730
},
16771731
}),
1678-
"127.0.0.1:0",
1732+
Some("127.0.0.1:0"),
16791733
target_addr.to_string(),
16801734
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
16811735
AttestationVerifier::expect_none(),
@@ -1754,7 +1808,7 @@ mod tests {
17541808
certificate_name: certificate_identity_from_chain(&cert_chain).unwrap(),
17551809
},
17561810
}),
1757-
"127.0.0.1:0",
1811+
Some("127.0.0.1:0"),
17581812
target_addr.to_string(),
17591813
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
17601814
AttestationVerifier::expect_none(),

0 commit comments

Comments
 (0)