Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 129 additions & 8 deletions src/auth/pkce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub struct PkceAuthProvider {
client_id: String,
scopes: Vec<String>,
redirect_port: u16,
redirect_uri: Option<String>,
app_id: String,
env_prefix: String,
/// In-process token cache keyed by env.
Expand Down Expand Up @@ -127,6 +128,7 @@ impl PkceAuthProvider {
client_id: client_id.into(),
scopes: scopes.iter().map(|s| s.as_ref().to_owned()).collect(),
redirect_port: REDIRECT_PORT_DEFAULT,
redirect_uri: None,
app_id: String::new(),
env_prefix,
cache: Arc::new(RwLock::new(HashMap::new())),
Expand All @@ -140,6 +142,18 @@ impl PkceAuthProvider {
self
}

/// Overrides the redirect URI sent to the authorization server.
///
/// By default the redirect URI is `http://127.0.0.1:{port}/callback`. Use
/// this when the OAuth client is allowlisted with a different URI, such as
/// `http://localhost:{port}/callback`. The local listener always binds to
/// `127.0.0.1` regardless of what is set here.
#[must_use]
pub fn with_redirect_uri(mut self, uri: impl Into<String>) -> Self {
self.redirect_uri = Some(uri.into());
self
}

/// Sets the application id used as the keychain service prefix.
#[must_use]
pub fn with_app_id(mut self, app_id: impl Into<String>) -> Self {
Expand Down Expand Up @@ -170,6 +184,27 @@ impl PkceAuthProvider {
std::env::var(&key).unwrap_or_else(|_| self.token_url.clone())
}

fn effective_redirect_uri(&self) -> String {
self.redirect_uri
.clone()
.unwrap_or_else(|| format!("http://127.0.0.1:{}/callback", self.redirect_port))
}

/// Parses the effective redirect URI and returns `(bind_port, callback_path)`.
fn parse_redirect_uri(&self) -> Result<(u16, String)> {
let uri_str = self.effective_redirect_uri();
let parsed = url::Url::parse(&uri_str)
.map_err(|e| CliCoreError::message(format!("invalid redirect URI '{uri_str}': {e}")))?;
let port = parsed
.port()
.or_else(|| parsed.port_or_known_default())
.ok_or_else(|| {
CliCoreError::message(format!("redirect URI '{uri_str}' has no port"))
})?;
let path = parsed.path().to_owned();
Ok((port, path))
}

fn keychain_service(&self, env: &str) -> String {
if self.app_id.is_empty() {
format!("{}/{}", self.name, env)
Expand Down Expand Up @@ -245,7 +280,7 @@ impl PkceAuthProvider {
let state = random_state();
let client_id = self.effective_client_id();
let auth_url = self.effective_auth_url();
let redirect_uri = format!("http://127.0.0.1:{}/callback", self.redirect_port);
let redirect_uri = self.effective_redirect_uri();
let scope = self.scopes.join(" ");

let auth_params = [
Expand All @@ -260,21 +295,23 @@ impl PkceAuthProvider {
let url = url::Url::parse_with_params(&auth_url, &auth_params)
.map_err(|err| CliCoreError::message(format!("invalid auth URL: {err}")))?;

let (bind_port, callback_path) = self.parse_redirect_uri()?;

// Start the local callback server before opening the browser so the
// redirect lands as soon as the user approves.
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], self.redirect_port)))
.map_err(|err| {
let listener =
TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], bind_port))).map_err(|err| {
CliCoreError::message(format!(
"failed to bind callback server on port {}: {err}",
self.redirect_port
"failed to bind callback server on port {bind_port}: {err}"
))
})?;

tracing::info!("Opening browser for authentication…");
tracing::info!("If the browser does not open, visit:\n {url}");
drop(open::that(url.as_str()));

let code = wait_for_callback(listener, &state, Duration::from_secs(120)).await?;
let code =
wait_for_callback(listener, &state, &callback_path, Duration::from_secs(120)).await?;
self.exchange_code_for_token(&code, &code_verifier, env)
.await
}
Expand All @@ -285,7 +322,7 @@ impl PkceAuthProvider {
code_verifier: &str,
env: &str,
) -> Result<StoredToken> {
let redirect_uri = format!("http://127.0.0.1:{}/callback", self.redirect_port);
let redirect_uri = self.effective_redirect_uri();
let client_id = self.effective_client_id();
let token_url = self.effective_token_url();

Expand Down Expand Up @@ -406,14 +443,15 @@ fn random_state() -> String {
URL_SAFE_NO_PAD.encode(bytes)
}

/// Waits for the OAuth callback on the given listener, validates state.
/// Waits for the OAuth callback on the given listener, validates state and path.
///
/// Accepts connections in a loop so that stray connections (port scanners,
/// browser preflight requests) do not consume the single callback attempt.
/// Uses async I/O so the future is properly cancelled on Ctrl+C.
async fn wait_for_callback(
listener: TcpListener,
expected_state: &str,
expected_path: &str,
timeout: Duration,
) -> Result<String> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
Expand All @@ -425,6 +463,7 @@ async fn wait_for_callback(
.map_err(|err| CliCoreError::message(format!("callback server setup failed: {err}")))?;

let expected_state = expected_state.to_owned();
let expected_path = expected_path.to_owned();
let result = tokio::time::timeout(timeout, async move {
loop {
let (mut stream, _) = match listener.accept().await {
Expand All @@ -445,6 +484,15 @@ async fn wait_for_callback(
};
let request = String::from_utf8_lossy(&buf[..n]);

if extract_request_path(&request).as_deref() != Some(expected_path.as_str()) {
drop(
stream
.write_all(b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n")
.await,
);
continue;
}

let code = extract_query_param(&request, "code");
let state = extract_query_param(&request, "state");

Expand Down Expand Up @@ -473,6 +521,18 @@ async fn wait_for_callback(
}
}

/// Extracts the path component from an HTTP request line (without query string).
fn extract_request_path(request: &str) -> Option<String> {
let line = request.lines().next()?;
let path_with_query = line.split_whitespace().nth(1)?;
Some(
path_with_query
.split_once('?')
.map_or(path_with_query, |(p, _)| p)
.to_owned(),
)
}

/// Extracts a query parameter value from an HTTP request line.
fn extract_query_param(request: &str, name: &str) -> Option<String> {
let line = request.lines().next()?;
Expand Down Expand Up @@ -565,6 +625,67 @@ mod tests {
);
}

#[test]
fn redirect_uri_default_uses_127_0_0_1_and_redirect_port() {
let provider = test_provider().with_redirect_port(9000);
assert_eq!(
provider.effective_redirect_uri(),
"http://127.0.0.1:9000/callback"
);
}

#[test]
fn with_redirect_uri_overrides_default() {
let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback");
assert_eq!(
provider.effective_redirect_uri(),
"http://localhost:8080/auth/callback"
);
}

#[test]
fn parse_redirect_uri_extracts_port_and_path_from_default() {
let provider = test_provider().with_redirect_port(9000);
let (port, path) = provider.parse_redirect_uri().expect("valid URI");
assert_eq!(port, 9000);
assert_eq!(path, "/callback");
}

#[test]
fn parse_redirect_uri_extracts_port_and_path_from_custom_uri() {
let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback");
let (port, path) = provider.parse_redirect_uri().expect("valid URI");
assert_eq!(port, 8080);
assert_eq!(path, "/auth/callback");
}

#[test]
fn with_redirect_uri_does_not_affect_listener_host() {
// The port is derived from the URI, but the listener always binds to
// 127.0.0.1 — this test confirms the URI host does not change that.
let provider = test_provider().with_redirect_uri("http://localhost:7777/callback");
let (port, _) = provider.parse_redirect_uri().expect("valid URI");
assert_eq!(port, 7777);
// Caller uses 127.0.0.1 for bind regardless; SocketAddr construction
// is in run_pkce_flow and is not repeated here.
}

#[test]
fn extract_request_path_strips_query_string() {
assert_eq!(
extract_request_path("GET /auth/callback?code=abc&state=xyz HTTP/1.1\r\n"),
Some("/auth/callback".to_owned()),
);
}

#[test]
fn extract_request_path_handles_no_query_string() {
assert_eq!(
extract_request_path("GET /callback HTTP/1.1\r\n"),
Some("/callback".to_owned()),
);
}

#[test]
fn extract_query_param_skips_malformed_pairs() {
let request = "GET /callback?foo&code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n";
Expand Down