From 14eeac9921d6d110ce932f061cfeb54eb2e136ca Mon Sep 17 00:00:00 2001 From: Jonas Strasel Date: Mon, 22 Jun 2026 16:31:17 +0200 Subject: [PATCH 1/5] feat(connections): add per-connection startup script Run optional SQL on every new pooled connection (MySQL/SQLite after_connect, Postgres deadpool post_create) so session settings like set_config apply to every query. Editable via a new "Advanced" tab in the connection editor. Closes #350 --- src-tauri/src/models.rs | 7 ++ src-tauri/src/plugins/driver.rs | 1 + src-tauri/src/pool_manager.rs | 81 +++++++++++----- src-tauri/src/pool_manager_tests.rs | 98 ++++++++++++++++++++ src/components/modals/NewConnectionModal.tsx | 44 ++++++++- src/contexts/DatabaseContext.ts | 1 + src/i18n/locales/de.json | 6 +- src/i18n/locales/en.json | 6 +- src/i18n/locales/es.json | 6 +- src/i18n/locales/fr.json | 6 +- src/i18n/locales/it.json | 6 +- src/i18n/locales/ja.json | 6 +- src/i18n/locales/ru.json | 6 +- src/i18n/locales/zh.json | 6 +- src/utils/connections.ts | 2 + 15 files changed, 247 insertions(+), 35 deletions(-) diff --git a/src-tauri/src/models.rs b/src-tauri/src/models.rs index eba9d080..a97bde45 100644 --- a/src-tauri/src/models.rs +++ b/src-tauri/src/models.rs @@ -157,6 +157,13 @@ pub struct ConnectionParams { pub k8s_resource_name: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub k8s_port: Option, + /// SQL run on every new physical connection in the pool (e.g. `SET` / + /// `set_config` for session-scoped settings such as bypassing RLS). + /// Statements are separated by `;`. Runs per pooled connection so the + /// setting applies to every query regardless of which connection the + /// pool hands out. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub startup_script: Option, // Connection ID for stable pooling (not persisted, set at runtime) #[serde(skip_serializing_if = "Option::is_none")] pub connection_id: Option, diff --git a/src-tauri/src/plugins/driver.rs b/src-tauri/src/plugins/driver.rs index 70aa77b9..b63e7f11 100644 --- a/src-tauri/src/plugins/driver.rs +++ b/src-tauri/src/plugins/driver.rs @@ -840,6 +840,7 @@ mod tests { k8s_resource_type: None, k8s_resource_name: None, k8s_port: None, + startup_script: None, connection_id: Some("conn-1".to_string()), } } diff --git a/src-tauri/src/pool_manager.rs b/src-tauri/src/pool_manager.rs index 12351aaa..09bca213 100644 --- a/src-tauri/src/pool_manager.rs +++ b/src-tauri/src/pool_manager.rs @@ -1,5 +1,5 @@ use crate::models::ConnectionParams; -use deadpool_postgres::{Manager as PgPoolManager, Pool as PgPool}; +use deadpool_postgres::{Hook as PgHook, HookError as PgHookError, Manager as PgPoolManager, Pool as PgPool}; use once_cell::sync::Lazy; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::client::{verify_server_cert_signed_by_trust_anchor, WebPkiServerVerifier}; @@ -11,7 +11,7 @@ use rustls::server::ParsedCertificate; use rustls::{DigitallySignedStruct}; use rustls::{ClientConfig, Error as TlsError, RootCertStore}; use rustls_platform_verifier::BuilderVerifierExt; -use sqlx::{sqlite::SqliteConnectOptions, MySql, Pool, Sqlite}; +use sqlx::{sqlite::SqliteConnectOptions, Executor, MySql, Pool, Sqlite}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -469,6 +469,18 @@ fn build_sqlite_connectoptions(params: &ConnectionParams) -> SqliteConnectOption SqliteConnectOptions::new().filename(params.database.to_string()) } +/// Return the connection's startup script if it is set and not blank. +/// Whitespace-only scripts are treated as absent so the per-connection +/// hook is skipped entirely rather than issuing an empty query. +fn startup_script(params: &ConnectionParams) -> Option { + params + .startup_script + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(str::to_owned) +} + pub async fn get_mysql_pool(params: &ConnectionParams) -> Result, String> { let connection_id = params.connection_id.as_deref(); get_mysql_pool_with_id(params, connection_id).await @@ -525,12 +537,17 @@ async fn get_mysql_pool_for_database_with_id( "connectTimeout", DEFAULT_MYSQL_CONNECT_TIMEOUT_MS, )); - let pool = tokio::time::timeout( - connect_timeout, - sqlx::mysql::MySqlPoolOptions::new() - .max_connections(10) - .connect_with(options), - ) + let mut pool_options = sqlx::mysql::MySqlPoolOptions::new().max_connections(10); + if let Some(script) = startup_script(params) { + pool_options = pool_options.after_connect(move |conn, _meta| { + let script = script.clone(); + Box::pin(async move { + conn.execute(script.as_str()).await?; + Ok(()) + }) + }); + } + let pool = tokio::time::timeout(connect_timeout, pool_options.connect_with(options)) .await .map_err(|_| { format!( @@ -597,14 +614,24 @@ pub async fn get_postgres_pool_with_id( e })?; - let pool = PgPool::builder(PgPoolManager::new(cfg, tls_connector)) - .max_size(10) - .build() - .map_err(|e| { - let detail = format_error_chain(&e); - log::error!("Failed to create PostgreSQL connection pool: {}", detail); - detail - })?; + let mut builder = PgPool::builder(PgPoolManager::new(cfg, tls_connector)).max_size(10); + if let Some(script) = startup_script(params) { + builder = builder.post_create(PgHook::async_fn(move |client, _metrics| { + let script = script.clone(); + Box::pin(async move { + client + .batch_execute(&script) + .await + .map_err(PgHookError::Backend)?; + Ok(()) + }) + })); + } + let pool = builder.build().map_err(|e| { + let detail = format_error_chain(&e); + log::error!("Failed to create PostgreSQL connection pool: {}", detail); + detail + })?; log::info!( "PostgreSQL connection pool created successfully for: {} (key: {})", @@ -652,14 +679,20 @@ pub async fn get_sqlite_pool_with_id( key ); let options = build_sqlite_connectoptions(params); - let pool = sqlx::sqlite::SqlitePoolOptions::new() - .max_connections(5) // SQLite has lower concurrency needs - .connect_with(options) - .await - .map_err(|e| { - log::error!("Failed to create SQLite connection pool: {}", e); - e.to_string() - })?; + let mut pool_options = sqlx::sqlite::SqlitePoolOptions::new().max_connections(5); // SQLite has lower concurrency needs + if let Some(script) = startup_script(params) { + pool_options = pool_options.after_connect(move |conn, _meta| { + let script = script.clone(); + Box::pin(async move { + conn.execute(script.as_str()).await?; + Ok(()) + }) + }); + } + let pool = pool_options.connect_with(options).await.map_err(|e| { + log::error!("Failed to create SQLite connection pool: {}", e); + e.to_string() + })?; log::info!( "SQLite connection pool created successfully for: {} (key: {})", diff --git a/src-tauri/src/pool_manager_tests.rs b/src-tauri/src/pool_manager_tests.rs index 6fb700f2..4c065740 100644 --- a/src-tauri/src/pool_manager_tests.rs +++ b/src-tauri/src/pool_manager_tests.rs @@ -372,3 +372,101 @@ mod postgres_tls_connector_tests { let _ = std::fs::remove_file(&file_path); } } + +#[cfg(test)] +mod startup_script_tests { + use crate::models::{ConnectionParams, DatabaseSelection}; + use crate::pool_manager::{close_pool_with_id, get_sqlite_pool_with_id}; + use tempfile::NamedTempFile; + + fn sqlite_params(path: &str, startup_script: Option<&str>) -> ConnectionParams { + ConnectionParams { + driver: "sqlite".to_string(), + database: DatabaseSelection::Single(path.to_string()), + startup_script: startup_script.map(ToOwned::to_owned), + ..Default::default() + } + } + + #[tokio::test] + async fn startup_script_runs_on_each_new_connection() { + let file = NamedTempFile::new().expect("temp file"); + let path = file.path().to_str().expect("utf8 path").to_string(); + // Unique connection id keeps this pool out of other tests' cached pools. + let conn_id = format!("startup-runs-{}", ulid::Ulid::new()); + + let params = sqlite_params( + &path, + Some( + "CREATE TABLE IF NOT EXISTS startup_marker (id INTEGER); \ + INSERT INTO startup_marker (id) VALUES (1);", + ), + ); + + let pool = get_sqlite_pool_with_id(¶ms, Some(&conn_id)) + .await + .expect("pool should be created"); + + // The marker table only exists if the startup script ran on the + // physical connection the pool just handed out. + let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM startup_marker") + .fetch_one(&pool) + .await + .expect("startup_marker table should exist"); + assert!(count >= 1, "expected at least one startup INSERT, got {count}"); + + close_pool_with_id(¶ms, Some(&conn_id)).await; + } + + #[tokio::test] + async fn blank_startup_script_is_skipped() { + let file = NamedTempFile::new().expect("temp file"); + let path = file.path().to_str().expect("utf8 path").to_string(); + let conn_id = format!("startup-blank-{}", ulid::Ulid::new()); + + // A whitespace-only script must be treated as absent: if it were run + // as SQL the connection would fail and `SELECT 1` below would error. + let params = sqlite_params(&path, Some(" \n ")); + + let pool = get_sqlite_pool_with_id(¶ms, Some(&conn_id)) + .await + .expect("pool should be created"); + + let (one,): (i64,) = sqlx::query_as("SELECT 1") + .fetch_one(&pool) + .await + .expect("query on pool with blank startup script should work"); + assert_eq!(one, 1); + + close_pool_with_id(¶ms, Some(&conn_id)).await; + } + + #[tokio::test] + async fn invalid_startup_script_surfaces_error() { + let file = NamedTempFile::new().expect("temp file"); + let path = file.path().to_str().expect("utf8 path").to_string(); + let conn_id = format!("startup-invalid-{}", ulid::Ulid::new()); + + let params = sqlite_params(&path, Some("THIS IS NOT VALID SQL;")); + + // The pool may build lazily, so the bad script can surface either at + // pool creation or on first acquire. Either way the error must reach + // the caller rather than silently succeeding. + let result = async { + let pool = get_sqlite_pool_with_id(¶ms, Some(&conn_id)).await?; + sqlx::query("SELECT 1") + .execute(&pool) + .await + .map_err(|e| e.to_string())?; + Ok::<_, String>(()) + } + .await; + + assert!( + result.is_err(), + "invalid startup script should fail the connection" + ); + + close_pool_with_id(¶ms, Some(&conn_id)).await; + } +} diff --git a/src/components/modals/NewConnectionModal.tsx b/src/components/modals/NewConnectionModal.tsx index 093a2d2e..26094c83 100644 --- a/src/components/modals/NewConnectionModal.tsx +++ b/src/components/modals/NewConnectionModal.tsx @@ -76,6 +76,8 @@ interface ConnectionParams { k8s_resource_type?: string; k8s_resource_name?: string; k8s_port?: number; + // SQL run on every new connection (e.g. SET / set_config) + startup_script?: string; } interface SavedConnection { @@ -186,7 +188,7 @@ export const NewConnectionModal = ({ // ── tab ── const [activeTab, setActiveTab] = useState< - "general" | "databases" | "ssh" | "ssl" | "k8s" | "appearance" + "general" | "databases" | "ssh" | "ssl" | "k8s" | "advanced" | "appearance" >("general"); // ── SSH ── @@ -1128,6 +1130,34 @@ export const NewConnectionModal = ({ /> ); + // ── rendered Advanced tab content (per-connection startup SQL) ── + const advancedTabContent = ( +
+ +

+ {t("newConnection.startupScriptDescription", { + defaultValue: + "SQL run on every new connection to this data source. Use it for session settings such as SET / set_config (e.g. bypassing RLS). Separate statements with semicolons.", + })} +

+