From b27710813569cbfa79e1f051ab7724f666bd6692 Mon Sep 17 00:00:00 2001 From: Jacob Page Date: Thu, 28 May 2026 17:03:06 -0700 Subject: [PATCH 1/2] feat: Allow hard-coded redirect URL The external CLI's OAuth clients are configured to only allow callback URLs that contain a specific port & "localhost" (not 127.0.0.1) as the hostname. Allow that to be specified. --- src/auth/pkce.rs | 137 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 10 deletions(-) diff --git a/src/auth/pkce.rs b/src/auth/pkce.rs index 14d89bc..4cca9db 100644 --- a/src/auth/pkce.rs +++ b/src/auth/pkce.rs @@ -96,6 +96,7 @@ pub struct PkceAuthProvider { client_id: String, scopes: Vec, redirect_port: u16, + redirect_uri: Option, app_id: String, env_prefix: String, /// In-process token cache keyed by env. @@ -127,6 +128,7 @@ impl PkceAuthProvider { client_id: client_id.into(), scopes: scopes.iter().map(|s| s.as_ref().to_owned()).collect(), redirect_port: REDIRECT_PORT_DEFAULT, + redirect_uri: None, app_id: String::new(), env_prefix, cache: Arc::new(RwLock::new(HashMap::new())), @@ -140,6 +142,18 @@ impl PkceAuthProvider { self } + /// Overrides the redirect URI sent to the authorization server. + /// + /// By default the redirect URI is `http://127.0.0.1:{port}/callback`. Use + /// this when the OAuth client is allowlisted with a different URI, such as + /// `http://localhost:{port}/callback`. The local listener always binds to + /// `127.0.0.1` regardless of what is set here. + #[must_use] + pub fn with_redirect_uri(mut self, uri: impl Into) -> Self { + self.redirect_uri = Some(uri.into()); + self + } + /// Sets the application id used as the keychain service prefix. #[must_use] pub fn with_app_id(mut self, app_id: impl Into) -> Self { @@ -170,6 +184,27 @@ impl PkceAuthProvider { std::env::var(&key).unwrap_or_else(|_| self.token_url.clone()) } + fn effective_redirect_uri(&self) -> String { + self.redirect_uri.clone().unwrap_or_else(|| { + format!("http://127.0.0.1:{}/callback", self.redirect_port) + }) + } + + /// Parses the effective redirect URI and returns `(bind_port, callback_path)`. + fn parse_redirect_uri(&self) -> Result<(u16, String)> { + let uri_str = self.effective_redirect_uri(); + let parsed = url::Url::parse(&uri_str) + .map_err(|e| CliCoreError::message(format!("invalid redirect URI '{uri_str}': {e}")))?; + let port = parsed + .port() + .or_else(|| parsed.port_or_known_default()) + .ok_or_else(|| { + CliCoreError::message(format!("redirect URI '{uri_str}' has no port")) + })?; + let path = parsed.path().to_owned(); + Ok((port, path)) + } + fn keychain_service(&self, env: &str) -> String { if self.app_id.is_empty() { format!("{}/{}", self.name, env) @@ -245,7 +280,7 @@ impl PkceAuthProvider { let state = random_state(); let client_id = self.effective_client_id(); let auth_url = self.effective_auth_url(); - let redirect_uri = format!("http://127.0.0.1:{}/callback", self.redirect_port); + let redirect_uri = self.effective_redirect_uri(); let scope = self.scopes.join(" "); let auth_params = [ @@ -260,13 +295,14 @@ impl PkceAuthProvider { let url = url::Url::parse_with_params(&auth_url, &auth_params) .map_err(|err| CliCoreError::message(format!("invalid auth URL: {err}")))?; + let (bind_port, callback_path) = self.parse_redirect_uri()?; + // Start the local callback server before opening the browser so the // redirect lands as soon as the user approves. - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], self.redirect_port))) + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], bind_port))) .map_err(|err| { CliCoreError::message(format!( - "failed to bind callback server on port {}: {err}", - self.redirect_port + "failed to bind callback server on port {bind_port}: {err}" )) })?; @@ -274,7 +310,8 @@ impl PkceAuthProvider { tracing::info!("If the browser does not open, visit:\n {url}"); drop(open::that(url.as_str())); - let code = wait_for_callback(listener, &state, Duration::from_secs(120)).await?; + let code = + wait_for_callback(listener, &state, &callback_path, Duration::from_secs(120)).await?; self.exchange_code_for_token(&code, &code_verifier, env) .await } @@ -285,7 +322,7 @@ impl PkceAuthProvider { code_verifier: &str, env: &str, ) -> Result { - let redirect_uri = format!("http://127.0.0.1:{}/callback", self.redirect_port); + let redirect_uri = self.effective_redirect_uri(); let client_id = self.effective_client_id(); let token_url = self.effective_token_url(); @@ -406,7 +443,7 @@ fn random_state() -> String { URL_SAFE_NO_PAD.encode(bytes) } -/// Waits for the OAuth callback on the given listener, validates state. +/// Waits for the OAuth callback on the given listener, validates state and path. /// /// Accepts connections in a loop so that stray connections (port scanners, /// browser preflight requests) do not consume the single callback attempt. @@ -414,6 +451,7 @@ fn random_state() -> String { async fn wait_for_callback( listener: TcpListener, expected_state: &str, + expected_path: &str, timeout: Duration, ) -> Result { use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -425,6 +463,7 @@ async fn wait_for_callback( .map_err(|err| CliCoreError::message(format!("callback server setup failed: {err}")))?; let expected_state = expected_state.to_owned(); + let expected_path = expected_path.to_owned(); let result = tokio::time::timeout(timeout, async move { loop { let (mut stream, _) = match listener.accept().await { @@ -445,6 +484,17 @@ async fn wait_for_callback( }; let request = String::from_utf8_lossy(&buf[..n]); + if extract_request_path(&request).as_deref() != Some(expected_path.as_str()) { + drop( + stream + .write_all( + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n", + ) + .await, + ); + continue; + } + let code = extract_query_param(&request, "code"); let state = extract_query_param(&request, "state"); @@ -467,12 +517,17 @@ async fn wait_for_callback( match result { Ok(inner) => inner, - Err(_) => Err(CliCoreError::message( - "timed out waiting for OAuth callback", - )), + Err(_) => Err(CliCoreError::message("timed out waiting for OAuth callback")), } } +/// Extracts the path component from an HTTP request line (without query string). +fn extract_request_path(request: &str) -> Option { + let line = request.lines().next()?; + let path_with_query = line.split_whitespace().nth(1)?; + Some(path_with_query.split_once('?').map_or(path_with_query, |(p, _)| p).to_owned()) +} + /// Extracts a query parameter value from an HTTP request line. fn extract_query_param(request: &str, name: &str) -> Option { let line = request.lines().next()?; @@ -565,6 +620,68 @@ mod tests { ); } + #[test] + fn redirect_uri_default_uses_127_0_0_1_and_redirect_port() { + let provider = test_provider().with_redirect_port(9000); + assert_eq!( + provider.effective_redirect_uri(), + "http://127.0.0.1:9000/callback" + ); + } + + #[test] + fn with_redirect_uri_overrides_default() { + let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback"); + assert_eq!( + provider.effective_redirect_uri(), + "http://localhost:8080/auth/callback" + ); + } + + #[test] + fn parse_redirect_uri_extracts_port_and_path_from_default() { + let provider = test_provider().with_redirect_port(9000); + let (port, path) = provider.parse_redirect_uri().expect("valid URI"); + assert_eq!(port, 9000); + assert_eq!(path, "/callback"); + } + + #[test] + fn parse_redirect_uri_extracts_port_and_path_from_custom_uri() { + let provider = + test_provider().with_redirect_uri("http://localhost:8080/auth/callback"); + let (port, path) = provider.parse_redirect_uri().expect("valid URI"); + assert_eq!(port, 8080); + assert_eq!(path, "/auth/callback"); + } + + #[test] + fn with_redirect_uri_does_not_affect_listener_host() { + // The port is derived from the URI, but the listener always binds to + // 127.0.0.1 — this test confirms the URI host does not change that. + let provider = test_provider().with_redirect_uri("http://localhost:7777/callback"); + let (port, _) = provider.parse_redirect_uri().expect("valid URI"); + assert_eq!(port, 7777); + // Caller uses 127.0.0.1 for bind regardless; SocketAddr construction + // is in run_pkce_flow and is not repeated here. + } + + #[test] + fn extract_request_path_strips_query_string() { + assert_eq!( + extract_request_path("GET /auth/callback?code=abc&state=xyz HTTP/1.1\r\n"), + Some("/auth/callback".to_owned()), + ); + } + + #[test] + fn extract_request_path_handles_no_query_string() { + assert_eq!( + extract_request_path("GET /callback HTTP/1.1\r\n"), + Some("/callback".to_owned()), + ); + } + #[test] fn extract_query_param_skips_malformed_pairs() { let request = "GET /callback?foo&code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n"; From b32670551c12c071816c75871b3155b1b64dad67 Mon Sep 17 00:00:00 2001 From: Jacob Page Date: Thu, 28 May 2026 17:12:52 -0700 Subject: [PATCH 2/2] Formatting :( --- src/auth/pkce.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/auth/pkce.rs b/src/auth/pkce.rs index 4cca9db..d0478d5 100644 --- a/src/auth/pkce.rs +++ b/src/auth/pkce.rs @@ -185,9 +185,9 @@ impl PkceAuthProvider { } fn effective_redirect_uri(&self) -> String { - self.redirect_uri.clone().unwrap_or_else(|| { - format!("http://127.0.0.1:{}/callback", self.redirect_port) - }) + self.redirect_uri + .clone() + .unwrap_or_else(|| format!("http://127.0.0.1:{}/callback", self.redirect_port)) } /// Parses the effective redirect URI and returns `(bind_port, callback_path)`. @@ -299,8 +299,8 @@ impl PkceAuthProvider { // Start the local callback server before opening the browser so the // redirect lands as soon as the user approves. - let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], bind_port))) - .map_err(|err| { + let listener = + TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], bind_port))).map_err(|err| { CliCoreError::message(format!( "failed to bind callback server on port {bind_port}: {err}" )) @@ -487,9 +487,7 @@ async fn wait_for_callback( if extract_request_path(&request).as_deref() != Some(expected_path.as_str()) { drop( stream - .write_all( - b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n", - ) + .write_all(b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n") .await, ); continue; @@ -517,7 +515,9 @@ async fn wait_for_callback( match result { Ok(inner) => inner, - Err(_) => Err(CliCoreError::message("timed out waiting for OAuth callback")), + Err(_) => Err(CliCoreError::message( + "timed out waiting for OAuth callback", + )), } } @@ -525,7 +525,12 @@ async fn wait_for_callback( fn extract_request_path(request: &str) -> Option { let line = request.lines().next()?; let path_with_query = line.split_whitespace().nth(1)?; - Some(path_with_query.split_once('?').map_or(path_with_query, |(p, _)| p).to_owned()) + Some( + path_with_query + .split_once('?') + .map_or(path_with_query, |(p, _)| p) + .to_owned(), + ) } /// Extracts a query parameter value from an HTTP request line. @@ -648,8 +653,7 @@ mod tests { #[test] fn parse_redirect_uri_extracts_port_and_path_from_custom_uri() { - let provider = - test_provider().with_redirect_uri("http://localhost:8080/auth/callback"); + let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback"); let (port, path) = provider.parse_redirect_uri().expect("valid URI"); assert_eq!(port, 8080); assert_eq!(path, "/auth/callback");