From 2e50e2e41cf6d84496b74c33fa5a72846a696e9e Mon Sep 17 00:00:00 2001 From: shiny-code-bot Date: Tue, 16 Jun 2026 12:25:25 -0400 Subject: [PATCH] Add loopback CORS support for Code Bridge --- codex-rs/code-bridge-service/src/lib.rs | 291 +++++++++++++++++++++++- 1 file changed, 283 insertions(+), 8 deletions(-) diff --git a/codex-rs/code-bridge-service/src/lib.rs b/codex-rs/code-bridge-service/src/lib.rs index fedf9e9631df..61aadc703adb 100644 --- a/codex-rs/code-bridge-service/src/lib.rs +++ b/codex-rs/code-bridge-service/src/lib.rs @@ -6,8 +6,16 @@ use axum::extract::Path as AxumPath; use axum::extract::Request; use axum::extract::State; use axum::http::HeaderMap; +use axum::http::HeaderValue; +use axum::http::Method; use axum::http::StatusCode; +use axum::http::header::ACCESS_CONTROL_ALLOW_HEADERS; +use axum::http::header::ACCESS_CONTROL_ALLOW_METHODS; +use axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN; +use axum::http::header::ACCESS_CONTROL_MAX_AGE; use axum::http::header::AUTHORIZATION; +use axum::http::header::ORIGIN; +use axum::http::header::VARY; use axum::middleware; use axum::middleware::Next; use axum::response::IntoResponse; @@ -54,6 +62,7 @@ use std::convert::Infallible; use std::fs::OpenOptions; use std::io; use std::io::Write; +use std::net::IpAddr; use std::net::Ipv4Addr; use std::net::SocketAddr; use std::path::Path; @@ -81,6 +90,12 @@ const MAX_PENDING_REQUESTS: usize = 256; const MAX_RETAINED_DELIVERY_BYTES: usize = 8 * 1024 * 1024; const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2); const EVENT_WAKE_CHANNEL_CAPACITY: usize = 64; +const CORS_ALLOW_HEADERS: &str = + "authorization, content-type, last-event-id, x-code-bridge-client-session"; +const CORS_ALLOW_METHODS: &str = "GET, POST, OPTIONS"; +const CORS_MAX_AGE_SECONDS: &str = "600"; +const CORS_VARY_HEADERS: &str = + "origin, access-control-request-method, access-control-request-headers"; #[derive(Debug, Clone)] pub struct BridgeServiceConfig { @@ -1080,6 +1095,8 @@ async fn events_handler( AxumPath(client_id): AxumPath, headers: HeaderMap, ) -> Result { + // Browser clients need a fetch-based SSE reader here: native EventSource + // cannot attach the bearer and client-session headers this endpoint uses. let client_session_token = client_session_token_from_headers(&headers)?; let last_seen_sequence = headers .get("last-event-id") @@ -1206,10 +1223,18 @@ async fn require_auth( "Code Bridge service only accepts loopback clients".to_string(), )); } - let token = bearer_token_from_headers(request.headers())?; + let cors_origin = cors_origin_from_headers(request.headers()); + if request.method() == Method::OPTIONS { + return Ok(preflight_response(cors_origin)); + } + let token = match bearer_token_from_headers(request.headers()) { + Ok(token) => token, + Err(error) => return Ok(cors_error_response(error, cors_origin)), + }; if !constant_time_eq(token.as_bytes(), state.auth_secret.as_bytes()) { - return Err(BridgeHttpError::Unauthorized( - "invalid Code Bridge bearer token".to_string(), + return Ok(cors_error_response( + BridgeHttpError::Unauthorized("invalid Code Bridge bearer token".to_string()), + cors_origin, )); } if let Some(content_length) = request @@ -1219,12 +1244,84 @@ async fn require_auth( .and_then(|value| value.parse::().ok()) && content_length > MAX_SCREENSHOT_MESSAGE_BYTES { - return Err(BridgeHttpError::PayloadTooLarge { - limit: MAX_SCREENSHOT_MESSAGE_BYTES, - actual: content_length, - }); + return Ok(cors_error_response( + BridgeHttpError::PayloadTooLarge { + limit: MAX_SCREENSHOT_MESSAGE_BYTES, + actual: content_length, + }, + cors_origin, + )); + } + let mut response = next.run(request).await; + add_cors_headers(response.headers_mut(), cors_origin); + Ok(response) +} + +fn cors_error_response(error: BridgeHttpError, cors_origin: Option) -> Response { + let mut response = error.into_response(); + add_cors_headers(response.headers_mut(), cors_origin); + response +} + +fn preflight_response(cors_origin: Option) -> Response { + let mut response = StatusCode::NO_CONTENT.into_response(); + add_cors_headers(response.headers_mut(), cors_origin); + response +} + +fn add_cors_headers(headers: &mut HeaderMap, cors_origin: Option) { + if let Some(origin) = cors_origin { + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + headers.insert( + ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static(CORS_ALLOW_METHODS), + ); + headers.insert( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static(CORS_ALLOW_HEADERS), + ); + headers.insert( + ACCESS_CONTROL_MAX_AGE, + HeaderValue::from_static(CORS_MAX_AGE_SECONDS), + ); + } + headers.insert(VARY, HeaderValue::from_static(CORS_VARY_HEADERS)); +} + +fn cors_origin_from_headers(headers: &HeaderMap) -> Option { + let origin = headers.get(ORIGIN)?; + if is_loopback_origin(origin) { + Some(origin.clone()) + } else { + None + } +} + +fn is_loopback_origin(origin: &HeaderValue) -> bool { + let Ok(origin) = origin.to_str() else { + return false; + }; + let Ok(uri) = origin.parse::() else { + return false; + }; + if !matches!(uri.scheme_str(), Some("http") | Some("https")) { + return false; + } + if uri + .path_and_query() + .is_some_and(|path| path.as_str() != "/") + { + return false; } - Ok(next.run(request).await) + let Some(host) = uri.host() else { + return false; + }; + let host = host.trim_matches(['[', ']']); + host.eq_ignore_ascii_case("localhost") + || host + .parse::() + .map(|ip| ip.is_loopback()) + .unwrap_or(false) } fn bearer_token_from_headers(headers: &HeaderMap) -> Result<&str, BridgeHttpError> { @@ -1610,11 +1707,16 @@ mod tests { let missing = client .post(&url) + .header(ORIGIN, "http://127.0.0.1:5173") .body("not json") .send() .await .expect("missing auth response"); assert_eq!(missing.status(), StatusCode::UNAUTHORIZED); + assert_eq!( + missing.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("http://127.0.0.1:5173")) + ); let invalid = client .post(&url) @@ -1637,6 +1739,179 @@ mod tests { service.handle.shutdown().await; } + #[tokio::test] + async fn accepts_loopback_browser_preflight_without_bearer_auth() { + let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await; + let client = Client::new(); + let url = format!("{}/message", service.handle.endpoint_url()); + + let response = client + .request(Method::OPTIONS, &url) + .header(ORIGIN, "http://127.0.0.1:5173") + .header("access-control-request-method", "POST") + .header( + "access-control-request-headers", + "authorization, content-type, x-code-bridge-client-session", + ) + .send() + .await + .expect("preflight response"); + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("http://127.0.0.1:5173")) + ); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), + Some(&HeaderValue::from_static(CORS_ALLOW_METHODS)) + ); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), + Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS)) + ); + + service.handle.shutdown().await; + } + + #[tokio::test] + async fn accepts_loopback_event_stream_preflight_without_bearer_auth() { + let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await; + let client = Client::new(); + let url = format!("{}/events/browser-client-1", service.handle.endpoint_url()); + + let response = client + .request(Method::OPTIONS, &url) + .header(ORIGIN, "http://[::1]:5173") + .header("access-control-request-method", "GET") + .header( + "access-control-request-headers", + "authorization, last-event-id, x-code-bridge-client-session", + ) + .send() + .await + .expect("preflight response"); + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("http://[::1]:5173")) + ); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), + Some(&HeaderValue::from_static(CORS_ALLOW_METHODS)) + ); + + service.handle.shutdown().await; + } + + #[tokio::test] + async fn includes_loopback_cors_headers_on_authenticated_event_streams() { + let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await; + let client = Client::new(); + let subscriber = register_subscriber( + &client, + &service.handle, + "browser-subscriber-1", + subscriber_capabilities(), + ) + .await; + + let response = client + .get(format!( + "{}/events/browser-subscriber-1", + service.handle.endpoint_url() + )) + .bearer_auth(service.handle.auth_secret()) + .header(CLIENT_SESSION_HEADER, subscriber.session_token.as_str()) + .header("last-event-id", "0") + .header(ORIGIN, "http://127.0.0.1:5173") + .send() + .await + .expect("event stream response"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("http://127.0.0.1:5173")) + ); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), + Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS)) + ); + drop(response); + + service.handle.shutdown().await; + } + + #[tokio::test] + async fn omits_cors_allow_origin_for_non_loopback_browser_origins() { + let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await; + let client = Client::new(); + let url = format!("{}/message", service.handle.endpoint_url()); + + let response = client + .request(Method::OPTIONS, &url) + .header(ORIGIN, "https://example.com") + .header("access-control-request-method", "POST") + .header("access-control-request-headers", "authorization") + .send() + .await + .expect("preflight response"); + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + assert!( + response + .headers() + .get(ACCESS_CONTROL_ALLOW_ORIGIN) + .is_none() + ); + assert_eq!( + response.headers().get(VARY), + Some(&HeaderValue::from_static(CORS_VARY_HEADERS)) + ); + + service.handle.shutdown().await; + } + + #[tokio::test] + async fn includes_loopback_cors_headers_on_authenticated_message_responses() { + let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await; + let client = Client::new(); + let envelope = hello_envelope( + "browser-client-1", + service.handle.auth_secret(), + ClientRole::Producer, + producer_capabilities(), + ); + let response = client + .post(format!("{}/message", service.handle.endpoint_url())) + .bearer_auth(service.handle.auth_secret()) + .header(ORIGIN, "http://localhost:3000") + .json(&envelope) + .send() + .await + .expect("message response"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), + Some(&HeaderValue::from_static("http://localhost:3000")) + ); + assert_eq!( + response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), + Some(&HeaderValue::from_static(CORS_ALLOW_HEADERS)) + ); + let payload = response + .json::() + .await + .expect("message response json") + .payload; + assert!(matches!(payload, BridgePayload::HelloResponse(_))); + + service.handle.shutdown().await; + } + #[tokio::test] async fn events_endpoint_requires_auth_and_registered_client() { let service = start_test_service(Duration::from_secs(30), Duration::from_secs(30)).await;