Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 283 additions & 8 deletions codex-rs/code-bridge-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@ use axum::extract::Path as AxumPath;
use axum::extract::Request;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::http::HeaderValue;
use axum::http::Method;
use axum::http::StatusCode;
use axum::http::header::ACCESS_CONTROL_ALLOW_HEADERS;
use axum::http::header::ACCESS_CONTROL_ALLOW_METHODS;
use axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN;
use axum::http::header::ACCESS_CONTROL_MAX_AGE;
use axum::http::header::AUTHORIZATION;
use axum::http::header::ORIGIN;
use axum::http::header::VARY;
use axum::middleware;
use axum::middleware::Next;
use axum::response::IntoResponse;
Expand Down Expand Up @@ -54,6 +62,7 @@ use std::convert::Infallible;
use std::fs::OpenOptions;
use std::io;
use std::io::Write;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
use std::path::Path;
Expand Down Expand Up @@ -81,6 +90,12 @@ const MAX_PENDING_REQUESTS: usize = 256;
const MAX_RETAINED_DELIVERY_BYTES: usize = 8 * 1024 * 1024;
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);
const EVENT_WAKE_CHANNEL_CAPACITY: usize = 64;
const CORS_ALLOW_HEADERS: &str =
"authorization, content-type, last-event-id, x-code-bridge-client-session";
const CORS_ALLOW_METHODS: &str = "GET, POST, OPTIONS";
const CORS_MAX_AGE_SECONDS: &str = "600";
const CORS_VARY_HEADERS: &str =
"origin, access-control-request-method, access-control-request-headers";

#[derive(Debug, Clone)]
pub struct BridgeServiceConfig {
Expand Down Expand Up @@ -1080,6 +1095,8 @@ async fn events_handler(
AxumPath(client_id): AxumPath<String>,
headers: HeaderMap,
) -> Result<impl IntoResponse, BridgeHttpError> {
// Browser clients need a fetch-based SSE reader here: native EventSource
// cannot attach the bearer and client-session headers this endpoint uses.
let client_session_token = client_session_token_from_headers(&headers)?;
let last_seen_sequence = headers
.get("last-event-id")
Expand Down Expand Up @@ -1206,10 +1223,18 @@ async fn require_auth(
"Code Bridge service only accepts loopback clients".to_string(),
));
}
let token = bearer_token_from_headers(request.headers())?;
let cors_origin = cors_origin_from_headers(request.headers());
if request.method() == Method::OPTIONS {
return Ok(preflight_response(cors_origin));
}
let token = match bearer_token_from_headers(request.headers()) {
Ok(token) => token,
Err(error) => return Ok(cors_error_response(error, cors_origin)),
};
if !constant_time_eq(token.as_bytes(), state.auth_secret.as_bytes()) {
return Err(BridgeHttpError::Unauthorized(
"invalid Code Bridge bearer token".to_string(),
return Ok(cors_error_response(
BridgeHttpError::Unauthorized("invalid Code Bridge bearer token".to_string()),
cors_origin,
));
}
if let Some(content_length) = request
Expand All @@ -1219,12 +1244,84 @@ async fn require_auth(
.and_then(|value| value.parse::<usize>().ok())
&& content_length > MAX_SCREENSHOT_MESSAGE_BYTES
{
return Err(BridgeHttpError::PayloadTooLarge {
limit: MAX_SCREENSHOT_MESSAGE_BYTES,
actual: content_length,
});
return Ok(cors_error_response(
BridgeHttpError::PayloadTooLarge {
limit: MAX_SCREENSHOT_MESSAGE_BYTES,
actual: content_length,
},
cors_origin,
));
}
let mut response = next.run(request).await;
add_cors_headers(response.headers_mut(), cors_origin);
Ok(response)
}

fn cors_error_response(error: BridgeHttpError, cors_origin: Option<HeaderValue>) -> Response {
let mut response = error.into_response();
add_cors_headers(response.headers_mut(), cors_origin);
response
}

fn preflight_response(cors_origin: Option<HeaderValue>) -> Response {
let mut response = StatusCode::NO_CONTENT.into_response();
add_cors_headers(response.headers_mut(), cors_origin);
response
}

fn add_cors_headers(headers: &mut HeaderMap, cors_origin: Option<HeaderValue>) {
if let Some(origin) = cors_origin {
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
headers.insert(
ACCESS_CONTROL_ALLOW_METHODS,
HeaderValue::from_static(CORS_ALLOW_METHODS),
);
headers.insert(
ACCESS_CONTROL_ALLOW_HEADERS,
HeaderValue::from_static(CORS_ALLOW_HEADERS),
);
headers.insert(
ACCESS_CONTROL_MAX_AGE,
HeaderValue::from_static(CORS_MAX_AGE_SECONDS),
);
}
headers.insert(VARY, HeaderValue::from_static(CORS_VARY_HEADERS));
}

fn cors_origin_from_headers(headers: &HeaderMap) -> Option<HeaderValue> {
let origin = headers.get(ORIGIN)?;
if is_loopback_origin(origin) {
Some(origin.clone())
} else {
None
}
}

fn is_loopback_origin(origin: &HeaderValue) -> bool {
let Ok(origin) = origin.to_str() else {
return false;
};
let Ok(uri) = origin.parse::<http::Uri>() else {
return false;
};
if !matches!(uri.scheme_str(), Some("http") | Some("https")) {
return false;
}
if uri
.path_and_query()
.is_some_and(|path| path.as_str() != "/")
{
return false;
}
Ok(next.run(request).await)
let Some(host) = uri.host() else {
return false;
};
let host = host.trim_matches(['[', ']']);
host.eq_ignore_ascii_case("localhost")
|| host
.parse::<IpAddr>()
.map(|ip| ip.is_loopback())
.unwrap_or(false)
}

fn bearer_token_from_headers(headers: &HeaderMap) -> Result<&str, BridgeHttpError> {
Expand Down Expand Up @@ -1610,11 +1707,16 @@ mod tests {

let missing = client
.post(&url)
.header(ORIGIN, "http://127.0.0.1:5173")
.body("not json")
.send()
.await
.expect("missing auth response");
assert_eq!(missing.status(), StatusCode::UNAUTHORIZED);
assert_eq!(
missing.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
Some(&HeaderValue::from_static("http://127.0.0.1:5173"))
);

let invalid = client
.post(&url)
Expand All @@ -1637,6 +1739,179 @@ mod tests {
service.handle.shutdown().await;
}

#[tokio::test]
async fn accepts_loopback_browser_preflight_without_bearer_auth() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
let client = Client::new();
let url = format!("{}/message", service.handle.endpoint_url());

let response = client
.request(Method::OPTIONS, &url)
.header(ORIGIN, "http://127.0.0.1:5173")
.header("access-control-request-method", "POST")
.header(
"access-control-request-headers",
"authorization, content-type, x-code-bridge-client-session",
)
.send()
.await
.expect("preflight response");

assert_eq!(response.status(), StatusCode::NO_CONTENT);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
Some(&HeaderValue::from_static("http://127.0.0.1:5173"))
);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_METHODS),
Some(&HeaderValue::from_static(CORS_ALLOW_METHODS))
);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS),
Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS))
);

service.handle.shutdown().await;
}

#[tokio::test]
async fn accepts_loopback_event_stream_preflight_without_bearer_auth() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
let client = Client::new();
let url = format!("{}/events/browser-client-1", service.handle.endpoint_url());

let response = client
.request(Method::OPTIONS, &url)
.header(ORIGIN, "http://[::1]:5173")
.header("access-control-request-method", "GET")
.header(
"access-control-request-headers",
"authorization, last-event-id, x-code-bridge-client-session",
)
.send()
.await
.expect("preflight response");

assert_eq!(response.status(), StatusCode::NO_CONTENT);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
Some(&HeaderValue::from_static("http://[::1]:5173"))
);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_METHODS),
Some(&HeaderValue::from_static(CORS_ALLOW_METHODS))
);

service.handle.shutdown().await;
}

#[tokio::test]
async fn includes_loopback_cors_headers_on_authenticated_event_streams() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
let client = Client::new();
let subscriber = register_subscriber(
&client,
&service.handle,
"browser-subscriber-1",
subscriber_capabilities(),
)
.await;

let response = client
.get(format!(
"{}/events/browser-subscriber-1",
service.handle.endpoint_url()
))
.bearer_auth(service.handle.auth_secret())
.header(CLIENT_SESSION_HEADER, subscriber.session_token.as_str())
.header("last-event-id", "0")
.header(ORIGIN, "http://127.0.0.1:5173")
.send()
.await
.expect("event stream response");

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
Some(&HeaderValue::from_static("http://127.0.0.1:5173"))
);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS),
Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS))
);
drop(response);

service.handle.shutdown().await;
}

#[tokio::test]
async fn omits_cors_allow_origin_for_non_loopback_browser_origins() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
let client = Client::new();
let url = format!("{}/message", service.handle.endpoint_url());

let response = client
.request(Method::OPTIONS, &url)
.header(ORIGIN, "https://example.com")
.header("access-control-request-method", "POST")
.header("access-control-request-headers", "authorization")
.send()
.await
.expect("preflight response");

assert_eq!(response.status(), StatusCode::NO_CONTENT);
assert!(
response
.headers()
.get(ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()
);
assert_eq!(
response.headers().get(VARY),
Some(&HeaderValue::from_static(CORS_VARY_HEADERS))
);

service.handle.shutdown().await;
}

#[tokio::test]
async fn includes_loopback_cors_headers_on_authenticated_message_responses() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
let client = Client::new();
let envelope = hello_envelope(
"browser-client-1",
service.handle.auth_secret(),
ClientRole::Producer,
producer_capabilities(),
);
let response = client
.post(format!("{}/message", service.handle.endpoint_url()))
.bearer_auth(service.handle.auth_secret())
.header(ORIGIN, "http://localhost:3000")
.json(&envelope)
.send()
.await
.expect("message response");

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
Some(&HeaderValue::from_static("http://localhost:3000"))
);
assert_eq!(
response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS),
Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS))
);
let payload = response
.json::<BridgeMessageResponse>()
.await
.expect("message response json")
.payload;
assert!(matches!(payload, BridgePayload::HelloResponse(_)));

service.handle.shutdown().await;
}

#[tokio::test]
async fn events_endpoint_requires_auth_and_registered_client() {
let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;
Expand Down
Loading