Skip to content

Commit 838ad3d

Browse files
authored
Merge pull request #142 from flashbots/peg/fixes-for-get-tls-cert
Add json measurement output option and generic TDX attestation type to make CLI similar to `cvm-reverse-proxy`
2 parents d36e549 + f8d7460 commit 838ad3d

4 files changed

Lines changed: 59 additions & 17 deletions

File tree

attested-tls/src/attestation/measurements.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,18 @@ impl MeasurementPolicy {
299299
}
300300
}
301301

302+
/// Accept any TDX attestation regardless of platform
303+
pub fn tdx() -> Self {
304+
Self {
305+
accepted_measurements: vec![
306+
MeasurementRecord::allow_any_measurement(AttestationType::DcapTdx),
307+
MeasurementRecord::allow_any_measurement(AttestationType::QemuTdx),
308+
MeasurementRecord::allow_any_measurement(AttestationType::GcpTdx),
309+
MeasurementRecord::allow_any_measurement(AttestationType::AzureTdx),
310+
],
311+
}
312+
}
313+
302314
/// Expect mock measurements used in tests
303315
#[cfg(any(test, feature = "mock"))]
304316
pub fn mock() -> Self {

attested-tls/src/lib.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,9 @@ impl AttestedTlsClient {
419419
pub async fn get_tls_cert(
420420
&self,
421421
server_name: &str,
422-
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
423-
let (mut tls_stream, _, _) = self.connect_tcp(server_name).await?;
422+
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
423+
let (mut tls_stream, measurements, _attestation_type) =
424+
self.connect_tcp(server_name).await?;
424425

425426
let (_io, server_connection) = tls_stream.get_ref();
426427

@@ -431,7 +432,7 @@ impl AttestedTlsClient {
431432

432433
tls_stream.shutdown().await?;
433434

434-
Ok(remote_cert_chain)
435+
Ok((remote_cert_chain, measurements))
435436
}
436437
}
437438

@@ -440,7 +441,7 @@ pub async fn get_tls_cert(
440441
server_name: String,
441442
attestation_verifier: AttestationVerifier,
442443
remote_certificate: Option<CertificateDer<'static>>,
443-
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
444+
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
444445
tracing::debug!("Getting remote TLS cert");
445446
let attested_tls_client = AttestedTlsClient::new(
446447
None,
@@ -458,7 +459,7 @@ pub async fn get_tls_cert_with_config(
458459
server_name: &str,
459460
attestation_verifier: AttestationVerifier,
460461
client_config: ClientConfig,
461-
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
462+
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
462463
let attested_tls_client = AttestedTlsClient::new_with_tls_config(
463464
client_config,
464465
AttestationGenerator::with_no_attestation(),

src/lib.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,21 @@ pub async fn get_tls_cert(
6565
attestation_verifier: AttestationVerifier,
6666
remote_certificate: Option<CertificateDer<'static>>,
6767
allow_self_signed: bool,
68-
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
69-
if allow_self_signed {
68+
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
69+
let (cert, measurements) = if allow_self_signed {
7070
let client_tls_config = self_signed::client_tls_config_allow_self_signed()?;
7171
attested_tls::get_tls_cert_with_config(
7272
&server_name,
7373
attestation_verifier,
7474
client_tls_config,
7575
)
76-
.await
76+
.await?
7777
} else {
78-
attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await
79-
}
78+
attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await?
79+
};
80+
81+
debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}");
82+
Ok((cert, measurements))
8083
}
8184

8285
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
@@ -1114,7 +1117,7 @@ mod tests {
11141117
proxy_server.accept().await.unwrap();
11151118
});
11161119

1117-
let retrieved_chain = get_tls_cert_with_config(
1120+
let (retrieved_chain, _measurements) = get_tls_cert_with_config(
11181121
&proxy_server_addr.to_string(),
11191122
AttestationVerifier::mock(),
11201123
client_config,

src/main.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use anyhow::{anyhow, ensure};
2+
use attested_tls::attestation::measurements::MultiMeasurements;
23
use clap::{Parser, Subcommand};
34
use std::{
45
fs::File,
@@ -126,6 +127,9 @@ enum CliCommand {
126127
/// Enables verification of self-signed TLS certificates
127128
#[arg(long)]
128129
allow_self_signed: bool,
130+
/// Filename to write measurements as JSON to
131+
#[arg(long)]
132+
out_measurements: Option<PathBuf>,
129133
},
130134
/// Serve a filesystem path over an attested channel
131135
AttestedFileServer {
@@ -201,12 +205,22 @@ async fn main() -> anyhow::Result<()> {
201205
MeasurementPolicy::from_file_or_url(server_measurements).await?
202206
}
203207
None => {
204-
let allowed_server_attestation_type: AttestationType = serde_json::from_value(
205-
serde_json::Value::String(cli.allowed_remote_attestation_type.ok_or(anyhow!(
208+
match cli
209+
.allowed_remote_attestation_type
210+
.ok_or(anyhow!(
206211
"Either a measurements file or an allowed attestation type must be provided"
207-
))?),
208-
)?;
209-
MeasurementPolicy::single_attestation_type(allowed_server_attestation_type)
212+
))?
213+
.to_lowercase()
214+
.as_str()
215+
{
216+
"tdx" => MeasurementPolicy::tdx(),
217+
attestation_type => {
218+
let allowed_server_attestation_type: AttestationType = serde_json::from_value(
219+
serde_json::Value::String(attestation_type.to_string()),
220+
)?;
221+
MeasurementPolicy::single_attestation_type(allowed_server_attestation_type)
222+
}
223+
}
210224
}
211225
};
212226

@@ -340,6 +354,7 @@ async fn main() -> anyhow::Result<()> {
340354
server,
341355
tls_ca_certificate,
342356
allow_self_signed,
357+
out_measurements,
343358
} => {
344359
let remote_tls_cert = match tls_ca_certificate {
345360
Some(remote_cert_filename) => Some(
@@ -350,13 +365,24 @@ async fn main() -> anyhow::Result<()> {
350365
),
351366
None => None,
352367
};
353-
let cert_chain = get_tls_cert(
368+
let (cert_chain, measurements) = get_tls_cert(
354369
server,
355370
attestation_verifier,
356371
remote_tls_cert,
357372
allow_self_signed,
358373
)
359374
.await?;
375+
376+
// If the user chose to write measurements to a file as JSON
377+
if let Some(path_to_write_measurements) = out_measurements {
378+
std::fs::write(
379+
path_to_write_measurements,
380+
measurements
381+
.unwrap_or(MultiMeasurements::NoAttestation)
382+
.to_header_format()?
383+
.as_bytes(),
384+
)?;
385+
}
360386
println!("{}", certs_to_pem_string(&cert_chain)?);
361387
}
362388
CliCommand::AttestedFileServer {

0 commit comments

Comments
 (0)