Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 267 additions & 1 deletion src-tauri/src/plugins/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::sync::{mpsc, oneshot};
use crate::drivers::driver_trait::{DatabaseDriver, PluginManifest};
use crate::models::{
ColumnDefinition, ConnectionParams, DataTypeInfo, ExplainPlan, ForeignKey, Index, QueryResult,
RoutineInfo, RoutineParameter, TableColumn, TableInfo, TableSchema, ViewInfo,
RoutineInfo, RoutineParameter, TableColumn, TableInfo, TableSchema, TriggerInfo, ViewInfo,
};
use crate::plugins::rpc::{JsonRpcRequest, JsonRpcResponse};

Expand Down Expand Up @@ -665,6 +665,81 @@ impl DatabaseDriver for RpcDriver {
Ok(())
}

async fn get_triggers(
&self,
params: &ConnectionParams,
schema: Option<&str>,
) -> Result<Vec<TriggerInfo>, String> {
let res = self
.process
.call(
"get_triggers",
json!({ "params": params, "schema": schema }),
)
.await?;
serde_json::from_value(res).map_err(|e| e.to_string())
}

async fn get_trigger_definition(
&self,
params: &ConnectionParams,
trigger_name: &str,
table_name: &str,
schema: Option<&str>,
) -> Result<String, String> {
let res = self
.process
.call(
"get_trigger_definition",
json!({
"params": params,
"trigger_name": trigger_name,
"table_name": table_name,
"schema": schema
}),
)
.await?;
serde_json::from_value(res).map_err(|e| e.to_string())
}

async fn create_trigger(
&self,
params: &ConnectionParams,
trigger_sql: &str,
schema: Option<&str>,
) -> Result<(), String> {
self
.process
.call(
"create_trigger",
json!({ "params": params, "trigger_sql": trigger_sql, "schema": schema }),
)
.await?;
Ok(())
}

async fn drop_trigger(
&self,
params: &ConnectionParams,
trigger_name: &str,
table_name: &str,
schema: Option<&str>,
) -> Result<(), String> {
self
.process
.call(
"drop_trigger",
json!({
"params": params,
"trigger_name": trigger_name,
"table_name": table_name,
"schema": schema
}),
)
.await?;
Ok(())
}

async fn get_schema_snapshot(
&self,
params: &ConnectionParams,
Expand Down Expand Up @@ -710,3 +785,194 @@ impl DatabaseDriver for RpcDriver {
serde_json::from_value(res).map_err(|e| e.to_string())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::drivers::driver_trait::DriverCapabilities;
use crate::models::DatabaseSelection;

fn test_manifest() -> PluginManifest {
PluginManifest {
id: "test-plugin".to_string(),
name: "Test Plugin".to_string(),
version: "1.0.0".to_string(),
description: "Test plugin".to_string(),
default_port: None,
capabilities: DriverCapabilities {
triggers: true,
..Default::default()
},
is_builtin: false,
default_username: String::new(),
color: String::new(),
icon: String::new(),
settings: Vec::new(),
ui_extensions: None,
}
}

fn test_connection_params() -> ConnectionParams {
ConnectionParams {
driver: "test-plugin".to_string(),
host: Some("localhost".to_string()),
port: Some(1234),
username: Some("user".to_string()),
password: Some("secret".to_string()),
database: DatabaseSelection::Single("db".to_string()),
ssl_mode: None,
ssl_ca: None,
ssl_cert: None,
ssl_key: None,
ssh_enabled: None,
ssh_connection_id: None,
ssh_host: None,
ssh_port: None,
ssh_user: None,
ssh_password: None,
ssh_key_file: None,
ssh_key_passphrase: None,
save_in_keychain: None,
k8s_enabled: None,
k8s_connection_id: None,
k8s_context: None,
k8s_namespace: None,
k8s_resource_type: None,
k8s_resource_name: None,
k8s_port: None,
connection_id: Some("conn-1".to_string()),
}
}

fn test_driver<F>(mut handle_request: F) -> RpcDriver
where
F: FnMut(JsonRpcRequest) -> Value + Send + 'static,
{
let (tx, mut rx) = mpsc::channel::<PluginCommand>(8);
tokio::spawn(async move {
while let Some(command) = rx.recv().await {
if let PluginCommand::Call(request, response_tx) = command {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
handle_request(request)
}))
.map_err(|_| "request assertion failed".to_string());
let _ = response_tx.send(result);
}
}
});

let (shutdown_tx, _shutdown_rx) = oneshot::channel();
RpcDriver {
manifest: test_manifest(),
process: Arc::new(PluginProcess {
sender: tx,
next_id: AtomicU64::new(1),
shutdown_tx: tokio::sync::Mutex::new(Some(shutdown_tx)),
pid: None,
}),
data_types: Vec::new(),
}
}

#[tokio::test]
async fn rpc_driver_forwards_get_triggers() {
let driver = test_driver(|request| {
assert_eq!(request.method, "get_triggers");
assert_eq!(request.params["schema"], "public");
assert_eq!(request.params["params"]["driver"], "test-plugin");
json!([
{
"name": "users_audit_trg",
"table_name": "users",
"event": "INSERT OR UPDATE",
"timing": "AFTER",
"definition": "CREATE TRIGGER users_audit_trg ..."
}
])
});

let triggers = driver
.get_triggers(&test_connection_params(), Some("public"))
.await
.expect("get_triggers");

assert_eq!(triggers.len(), 1);
assert_eq!(triggers[0].name, "users_audit_trg");
assert_eq!(triggers[0].table_name, "users");
assert_eq!(triggers[0].event, "INSERT OR UPDATE");
assert_eq!(triggers[0].timing, "AFTER");
assert_eq!(
triggers[0].definition.as_deref(),
Some("CREATE TRIGGER users_audit_trg ...")
);
}

#[tokio::test]
async fn rpc_driver_forwards_get_trigger_definition() {
let driver = test_driver(|request| {
assert_eq!(request.method, "get_trigger_definition");
assert_eq!(request.params["trigger_name"], "users_audit_trg");
assert_eq!(request.params["table_name"], "users");
assert_eq!(request.params["schema"], "public");
assert_eq!(request.params["params"]["driver"], "test-plugin");
json!("CREATE TRIGGER users_audit_trg ...")
});

let definition = driver
.get_trigger_definition(
&test_connection_params(),
"users_audit_trg",
"users",
Some("public"),
)
.await
.expect("get_trigger_definition");

assert_eq!(definition, "CREATE TRIGGER users_audit_trg ...");
}

#[tokio::test]
async fn rpc_driver_forwards_create_trigger() {
let driver = test_driver(|request| {
assert_eq!(request.method, "create_trigger");
assert_eq!(
request.params["trigger_sql"],
"CREATE TRIGGER users_audit_trg ..."
);
assert_eq!(request.params["schema"], "public");
assert_eq!(request.params["params"]["driver"], "test-plugin");
Value::Null
});

driver
.create_trigger(
&test_connection_params(),
"CREATE TRIGGER users_audit_trg ...",
Some("public"),
)
.await
.expect("create_trigger");
}

#[tokio::test]
async fn rpc_driver_forwards_drop_trigger() {
let driver = test_driver(|request| {
assert_eq!(request.method, "drop_trigger");
assert_eq!(request.params["trigger_name"], "users_audit_trg");
assert_eq!(request.params["table_name"], "users");
assert_eq!(request.params["schema"], "public");
assert_eq!(request.params["params"]["driver"], "test-plugin");
Value::Null
});

driver
.drop_trigger(
&test_connection_params(),
"users_audit_trg",
"users",
Some("public"),
)
.await
.expect("drop_trigger");
}
}