Skip to content

Commit 4650de6

Browse files
committed
Fix ALPN
1 parent 9dfb3f2 commit 4650de6

1 file changed

Lines changed: 58 additions & 8 deletions

File tree

src/lib.rs

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,8 @@ pub async fn get_inner_tls_cert(
9797
pub async fn get_inner_tls_cert_with_config(
9898
server_name: String,
9999
attestation_verifier: AttestationVerifier,
100-
mut outer_client_config: ClientConfig,
100+
outer_client_config: ClientConfig,
101101
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
102-
ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols);
103102
let outbound_stream = tokio::net::TcpStream::connect(&server_name).await?;
104103

105104
let domain = server_name_from_host(&server_name)?;
@@ -205,19 +204,18 @@ impl ProxyServer {
205204
/// Start with preconfigured TLS and require client auth on both nested sessions
206205
pub async fn new_with_tls_config_and_client_auth(
207206
cert_chain: Vec<CertificateDer<'static>>,
208-
mut outer_server_config: ServerConfig,
207+
outer_server_config: ServerConfig,
209208
local: impl ToSocketAddrs,
210209
target: String,
211210
attestation_generator: AttestationGenerator,
212211
attestation_verifier: AttestationVerifier,
213212
client_auth: bool,
214213
) -> Result<Self, ProxyError> {
215-
ensure_proxy_alpn_protocols(&mut outer_server_config.alpn_protocols);
216214
let server_name = certificate_identity_from_chain(&cert_chain)?;
217215
let inner_cert_resolver =
218216
build_attested_cert_resolver(attestation_generator, &cert_chain).await?;
219217

220-
let inner_server_config = if client_auth {
218+
let mut inner_server_config = if client_auth {
221219
let attested_cert_verifier =
222220
AttestedCertificateVerifier::new(None, attestation_verifier)?;
223221
ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
@@ -230,6 +228,8 @@ impl ProxyServer {
230228
.with_cert_resolver(Arc::new(inner_cert_resolver))
231229
};
232230

231+
ensure_proxy_alpn_protocols(&mut inner_server_config.alpn_protocols);
232+
233233
let nesting_tls_acceptor =
234234
NestingTlsAcceptor::new(Arc::new(outer_server_config), Arc::new(inner_server_config));
235235
let listener = TcpListener::bind(local).await?;
@@ -440,14 +440,13 @@ impl ProxyClient {
440440

441441
/// Create a new proxy client with given TLS configuration
442442
pub async fn new_with_tls_config(
443-
mut outer_client_config: ClientConfig,
443+
outer_client_config: ClientConfig,
444444
address: impl ToSocketAddrs,
445445
target_name: String,
446446
attestation_generator: AttestationGenerator,
447447
attestation_verifier: AttestationVerifier,
448448
cert_chain: Option<Vec<CertificateDer<'static>>>,
449449
) -> Result<Self, ProxyError> {
450-
ensure_proxy_alpn_protocols(&mut outer_client_config.alpn_protocols);
451450
let outer_has_client_auth = outer_client_config.client_auth_cert_resolver.has_certs();
452451
let inner_has_client_auth = cert_chain.is_some();
453452

@@ -457,7 +456,7 @@ impl ProxyClient {
457456

458457
let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?;
459458

460-
let inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() {
459+
let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() {
461460
let inner_cert_resolver =
462461
build_attested_cert_resolver(attestation_generator, cert_chain).await?;
463462
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
@@ -470,6 +469,7 @@ impl ProxyClient {
470469
.with_custom_certificate_verifier(Arc::new(attested_cert_verifier))
471470
.with_no_client_auth()
472471
};
472+
ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols);
473473

474474
let nesting_tls_connector =
475475
NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config));
@@ -905,6 +905,56 @@ mod tests {
905905
assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]);
906906
}
907907

908+
#[tokio::test(flavor = "multi_thread")]
909+
async fn http_proxy_negotiates_http2_by_default() {
910+
let target_addr = example_http_service().await;
911+
912+
let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost");
913+
let (server_config, outer_client_config) =
914+
generate_tls_config(cert_chain.clone(), private_key);
915+
916+
let proxy_server = ProxyServer::new_with_tls_config(
917+
cert_chain,
918+
server_config,
919+
"127.0.0.1:0",
920+
target_addr.to_string(),
921+
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
922+
AttestationVerifier::expect_none(),
923+
)
924+
.await
925+
.unwrap();
926+
927+
let proxy_addr = proxy_server.local_addr().unwrap();
928+
929+
tokio::spawn(async move {
930+
proxy_server.accept().await.unwrap();
931+
});
932+
933+
let attested_cert_verifier =
934+
AttestedCertificateVerifier::new(None, AttestationVerifier::mock()).unwrap();
935+
let mut inner_client_config =
936+
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
937+
.dangerous()
938+
.with_custom_certificate_verifier(Arc::new(attested_cert_verifier))
939+
.with_no_client_auth();
940+
ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols);
941+
942+
let nesting_tls_connector = NestingTlsConnector::new(
943+
Arc::new(outer_client_config),
944+
Arc::new(inner_client_config),
945+
);
946+
947+
let (sender, conn) = ProxyClient::setup_connection(
948+
&nesting_tls_connector,
949+
&format!("localhost:{}", proxy_addr.port()),
950+
)
951+
.await
952+
.unwrap();
953+
954+
assert!(matches!(sender, HttpSender::Http2(_)));
955+
assert!(matches!(conn, HttpConnection::Http2 { .. }));
956+
}
957+
908958
#[tokio::test(flavor = "multi_thread")]
909959
async fn http_proxy_default_constructors_work() {
910960
let target_addr = example_http_service().await;

0 commit comments

Comments
 (0)