Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src-tauri/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ pub struct ConnectionParams {
pub k8s_resource_name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub k8s_port: Option<u16>,
/// SQL run on every new physical connection in the pool (e.g. `SET` /
/// `set_config` for session-scoped settings such as bypassing RLS).
/// Statements are separated by `;`. Runs per pooled connection so the
/// setting applies to every query regardless of which connection the
/// pool hands out.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub startup_script: Option<String>,
// Connection ID for stable pooling (not persisted, set at runtime)
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_id: Option<String>,
Expand Down
1 change: 1 addition & 0 deletions src-tauri/src/plugins/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ mod tests {
k8s_resource_type: None,
k8s_resource_name: None,
k8s_port: None,
startup_script: None,
connection_id: Some("conn-1".to_string()),
}
}
Expand Down
155 changes: 130 additions & 25 deletions src-tauri/src/pool_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::models::ConnectionParams;
use deadpool_postgres::{Manager as PgPoolManager, Pool as PgPool};
use deadpool_postgres::{Hook as PgHook, HookError as PgHookError, Manager as PgPoolManager, Pool as PgPool};
use once_cell::sync::Lazy;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::client::{verify_server_cert_signed_by_trust_anchor, WebPkiServerVerifier};
Expand All @@ -11,7 +11,8 @@ use rustls::server::ParsedCertificate;
use rustls::{DigitallySignedStruct};
use rustls::{ClientConfig, Error as TlsError, RootCertStore};
use rustls_platform_verifier::BuilderVerifierExt;
use sqlx::{sqlite::SqliteConnectOptions, MySql, Pool, Sqlite};
use sha2::{Digest, Sha256};
use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions, Connection, Executor, MySql, Pool, Sqlite};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -51,6 +52,11 @@ static SQLITE_POOLS: Lazy<PoolMap<Sqlite>> = Lazy::new(|| Arc::new(RwLock::new(H
const DEFAULT_MYSQL_CONNECT_TIMEOUT_MS: u64 = 60_000;
const DEFAULT_MYSQL_TIMEZONE: &str = "SYSTEM";

/// SQLite is file-based so the preflight is effectively local, but a custom
/// VFS or a path on a stalled network mount could still hang it; bound it so a
/// broken script can never wedge pool creation indefinitely.
const SQLITE_STARTUP_SCRIPT_TIMEOUT_MS: u64 = 30_000;

fn mysql_setting_value(key: &str) -> Option<serde_json::Value> {
crate::config::get_cached_config()
.plugins
Expand Down Expand Up @@ -116,10 +122,23 @@ pub(crate) fn build_connection_key(
)
};

if let Some(tls_key) = tls_key {
let key = if let Some(tls_key) = tls_key {
format!("{base_key}:{tls_key}")
} else {
base_key
};

// Fold the startup script into the key so editing it forces a fresh pool
// (whose new connections run the new script) instead of silently reusing
// the cached pool keyed only by connection_id. Hashed to keep the key
// bounded; only present when a script is set, so script-free connections
// keep their existing keys.
match startup_script(params) {
Some(script) => {
let digest = Sha256::digest(script.as_bytes());
format!("{key}:startup:{digest:x}")
}
None => key,
}
}

Expand Down Expand Up @@ -469,6 +488,54 @@ fn build_sqlite_connectoptions(params: &ConnectionParams) -> SqliteConnectOption
SqliteConnectOptions::new().filename(params.database.to_string())
}

/// Return the connection's startup script if it is set and not blank.
/// Whitespace-only scripts are treated as absent so the per-connection
/// hook is skipped entirely rather than issuing an empty query.
fn startup_script(params: &ConnectionParams) -> Option<String> {
params
.startup_script
.as_ref()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(str::to_owned)
}

/// Format a startup-script execution failure so the surfaced error clearly
/// names the startup script as the cause, instead of reading like a bad host
/// or wrong credentials.
fn startup_script_error(err: impl std::fmt::Display) -> String {
format!("Startup script failed: {err}")
}

/// Run the startup script once on a throwaway connection so a broken script
/// fails fast with a clearly attributed error. The per-connection hooks
/// (`after_connect`/`post_create`) still run it on every pooled connection;
/// this preflight exists only for early, well-labelled failures: sqlx swallows
/// `after_connect` errors and retries until the acquire timeout, which would
/// otherwise report a misleading "pool timed out". A failure to open the
/// connection is returned verbatim so genuine connectivity problems are not
/// mislabelled as startup-script errors.
async fn run_mysql_startup_script(
options: &sqlx::mysql::MySqlConnectOptions,
script: &str,
) -> Result<(), String> {
let mut conn = options.connect().await.map_err(|e| e.to_string())?;
let outcome = conn.execute(script).await;
let _ = conn.close().await;
outcome.map(|_| ()).map_err(startup_script_error)
}

/// SQLite counterpart to [`run_mysql_startup_script`].
async fn run_sqlite_startup_script(
options: &SqliteConnectOptions,
script: &str,
) -> Result<(), String> {
let mut conn = options.connect().await.map_err(|e| e.to_string())?;
let outcome = conn.execute(script).await;
let _ = conn.close().await;
outcome.map(|_| ()).map_err(startup_script_error)
}

pub async fn get_mysql_pool(params: &ConnectionParams) -> Result<Pool<MySql>, String> {
let connection_id = params.connection_id.as_deref();
get_mysql_pool_with_id(params, connection_id).await
Expand Down Expand Up @@ -525,12 +592,25 @@ async fn get_mysql_pool_for_database_with_id(
"connectTimeout",
DEFAULT_MYSQL_CONNECT_TIMEOUT_MS,
));
let pool = tokio::time::timeout(
connect_timeout,
sqlx::mysql::MySqlPoolOptions::new()
.max_connections(10)
.connect_with(options),
)
let mut pool_options = sqlx::mysql::MySqlPoolOptions::new().max_connections(10);
if let Some(script) = startup_script(params) {
tokio::time::timeout(connect_timeout, run_mysql_startup_script(&options, &script))
.await
.map_err(|_| {
format!(
"Timed out running MySQL startup script after {} ms",
connect_timeout.as_millis()
)
})??;
pool_options = pool_options.after_connect(move |conn, _meta| {
let script = script.clone();
Box::pin(async move {
conn.execute(script.as_str()).await?;
Ok(())
})
});
}
let pool = tokio::time::timeout(connect_timeout, pool_options.connect_with(options))
.await
.map_err(|_| {
format!(
Expand Down Expand Up @@ -597,14 +677,24 @@ pub async fn get_postgres_pool_with_id(
e
})?;

let pool = PgPool::builder(PgPoolManager::new(cfg, tls_connector))
.max_size(10)
.build()
.map_err(|e| {
let detail = format_error_chain(&e);
log::error!("Failed to create PostgreSQL connection pool: {}", detail);
detail
})?;
let mut builder = PgPool::builder(PgPoolManager::new(cfg, tls_connector)).max_size(10);
if let Some(script) = startup_script(params) {
builder = builder.post_create(PgHook::async_fn(move |client, _metrics| {
let script = script.clone();
Box::pin(async move {
client
.batch_execute(&script)
.await
.map_err(|e| PgHookError::message(startup_script_error(format_error_chain(&e))))?;
Ok(())
})
}));
}
let pool = builder.build().map_err(|e| {
let detail = format_error_chain(&e);
log::error!("Failed to create PostgreSQL connection pool: {}", detail);
detail
})?;

log::info!(
"PostgreSQL connection pool created successfully for: {} (key: {})",
Expand Down Expand Up @@ -652,14 +742,29 @@ pub async fn get_sqlite_pool_with_id(
key
);
let options = build_sqlite_connectoptions(params);
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5) // SQLite has lower concurrency needs
.connect_with(options)
.await
.map_err(|e| {
log::error!("Failed to create SQLite connection pool: {}", e);
e.to_string()
})?;
let mut pool_options = sqlx::sqlite::SqlitePoolOptions::new().max_connections(5); // SQLite has lower concurrency needs
if let Some(script) = startup_script(params) {
let timeout = Duration::from_millis(SQLITE_STARTUP_SCRIPT_TIMEOUT_MS);
tokio::time::timeout(timeout, run_sqlite_startup_script(&options, &script))
.await
.map_err(|_| {
format!(
"Timed out running SQLite startup script after {} ms",
timeout.as_millis()
)
})??;
pool_options = pool_options.after_connect(move |conn, _meta| {
let script = script.clone();
Box::pin(async move {
conn.execute(script.as_str()).await?;
Ok(())
})
});
}
let pool = pool_options.connect_with(options).await.map_err(|e| {
log::error!("Failed to create SQLite connection pool: {}", e);
e.to_string()
})?;

log::info!(
"SQLite connection pool created successfully for: {} (key: {})",
Expand Down
123 changes: 123 additions & 0 deletions src-tauri/src/pool_manager_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,38 @@ mod tests {
);
}

#[test]
fn pool_key_changes_when_startup_script_changes() {
let none = connection_params("postgres", Some("require"));
let mut script_a = none.clone();
script_a.startup_script = Some("SET app.bypass_rls = 'on';".to_string());
let mut script_b = none.clone();
script_b.startup_script = Some("SET app.bypass_rls = 'off';".to_string());

let key_none = build_connection_key(&none, Some("conn-1"));
let key_a = build_connection_key(&script_a, Some("conn-1"));
let key_b = build_connection_key(&script_b, Some("conn-1"));

// A script changes the key, and different scripts differ — otherwise an
// edited startup script would silently reuse the old cached pool.
assert_ne!(key_none, key_a);
assert_ne!(key_a, key_b);
}

#[test]
fn pool_key_ignores_blank_startup_script() {
let none = connection_params("postgres", Some("require"));
let mut blank = none.clone();
blank.startup_script = Some(" \n\t".to_string());

// Whitespace-only scripts are treated as absent (no hook runs), so they
// must not fragment the pool away from the no-script connection.
assert_eq!(
build_connection_key(&none, Some("conn-1")),
build_connection_key(&blank, Some("conn-1"))
);
}

#[test]
fn mysql_options_accept_snake_case_verify_ssl_modes() {
let verify_ca = mysql_params("verify_ca");
Expand Down Expand Up @@ -372,3 +404,94 @@ mod postgres_tls_connector_tests {
let _ = std::fs::remove_file(&file_path);
}
}

#[cfg(test)]
mod startup_script_tests {
use crate::models::{ConnectionParams, DatabaseSelection};
use crate::pool_manager::{close_pool_with_id, get_sqlite_pool_with_id};
use tempfile::NamedTempFile;

fn sqlite_params(path: &str, startup_script: Option<&str>) -> ConnectionParams {
ConnectionParams {
driver: "sqlite".to_string(),
database: DatabaseSelection::Single(path.to_string()),
startup_script: startup_script.map(ToOwned::to_owned),
..Default::default()
}
}

#[tokio::test]
async fn startup_script_runs_on_each_new_connection() {
let file = NamedTempFile::new().expect("temp file");
let path = file.path().to_str().expect("utf8 path").to_string();
// Unique connection id keeps this pool out of other tests' cached pools.
let conn_id = format!("startup-runs-{}", ulid::Ulid::new());

let params = sqlite_params(
&path,
Some(
"CREATE TABLE IF NOT EXISTS startup_marker (id INTEGER); \
INSERT INTO startup_marker (id) VALUES (1);",
),
);

let pool = get_sqlite_pool_with_id(&params, Some(&conn_id))
.await
.expect("pool should be created");

// The marker table only exists if the startup script ran on the
// physical connection the pool just handed out.
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM startup_marker")
.fetch_one(&pool)
.await
.expect("startup_marker table should exist");
assert!(count >= 1, "expected at least one startup INSERT, got {count}");

close_pool_with_id(&params, Some(&conn_id)).await;
}

#[tokio::test]
async fn blank_startup_script_is_skipped() {
let file = NamedTempFile::new().expect("temp file");
let path = file.path().to_str().expect("utf8 path").to_string();
let conn_id = format!("startup-blank-{}", ulid::Ulid::new());

// A whitespace-only script must be treated as absent: if it were run
// as SQL the connection would fail and `SELECT 1` below would error.
let params = sqlite_params(&path, Some(" \n "));

let pool = get_sqlite_pool_with_id(&params, Some(&conn_id))
.await
.expect("pool should be created");

let (one,): (i64,) = sqlx::query_as("SELECT 1")
.fetch_one(&pool)
.await
.expect("query on pool with blank startup script should work");
assert_eq!(one, 1);

close_pool_with_id(&params, Some(&conn_id)).await;
}

#[tokio::test]
async fn invalid_startup_script_surfaces_attributed_error() {
let file = NamedTempFile::new().expect("temp file");
let path = file.path().to_str().expect("utf8 path").to_string();
let conn_id = format!("startup-invalid-{}", ulid::Ulid::new());

let params = sqlite_params(&path, Some("THIS IS NOT VALID SQL;"));

// A broken startup script must fail the connection with an error that
// clearly names the startup script as the cause, rather than sqlx's
// misleading "pool timed out" or a generic connection error.
let err = get_sqlite_pool_with_id(&params, Some(&conn_id))
.await
.expect_err("invalid startup script should fail the connection");
assert!(
err.contains("Startup script failed"),
"error should be attributed to the startup script, got: {err}"
);

close_pool_with_id(&params, Some(&conn_id)).await;
}
}
Loading