diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index cec36e18..d2abc88e 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -2683,17 +2683,15 @@ pub async fn delete_record( app: AppHandle, connection_id: String, table: String, - pk_col: String, - pk_val: serde_json::Value, + pk_map: std::collections::HashMap, schema: Option, database: Option, ) -> Result { log::info!( - "Executing query on connection: {} | Query: DELETE FROM {} WHERE {} = {}", + "Executing query on connection: {} | Query: DELETE FROM {} WHERE pk_map={:?}", connection_id, table, - pk_col, - pk_val + pk_map ); let saved_conn = find_connection_by_id(&app, &connection_id)?; let expanded_params = expand_ssh_connection_params(&app, &saved_conn.params).await?; @@ -2703,7 +2701,7 @@ pub async fn delete_record( params.database = crate::models::DatabaseSelection::Single(db); } let drv = driver_for(&saved_conn.params.driver).await?; - drv.delete_record(¶ms, &table, &pk_col, pk_val, schema.as_deref()) + drv.delete_record(¶ms, &table, &pk_map, schema.as_deref()) .await } @@ -2712,21 +2710,19 @@ pub async fn update_record( app: AppHandle, connection_id: String, table: String, - pk_col: String, - pk_val: serde_json::Value, + pk_map: std::collections::HashMap, col_name: String, new_val: serde_json::Value, schema: Option, database: Option, ) -> Result { log::info!( - "Executing query on connection: {} | Query: UPDATE {} SET {} = {} WHERE {} = {}", + "Executing query on connection: {} | Query: UPDATE {} SET {} = {:?} WHERE pk_map={:?}", connection_id, table, col_name, new_val, - pk_col, - pk_val + pk_map ); let saved_conn = find_connection_by_id(&app, &connection_id)?; let expanded_params = expand_ssh_connection_params(&app, &saved_conn.params).await?; @@ -2740,8 +2736,7 @@ pub async fn update_record( drv.update_record( ¶ms, &table, - &pk_col, - pk_val, + &pk_map, &col_name, new_val, schema.as_deref(), @@ -2756,8 +2751,7 @@ pub async fn save_blob_to_file( connection_id: String, table: String, col_name: String, - pk_col: String, - pk_val: serde_json::Value, + pk_map: std::collections::HashMap, file_path: String, schema: Option, ) -> Result<(), String> { @@ -2770,8 +2764,7 @@ pub async fn save_blob_to_file( ¶ms, &table, &col_name, - &pk_col, - pk_val, + &pk_map, schema.as_deref(), &file_path, ) @@ -2786,8 +2779,7 @@ pub async fn fetch_blob_as_data_url( connection_id: String, table: String, col_name: String, - pk_col: String, - pk_val: serde_json::Value, + pk_map: std::collections::HashMap, schema: Option, ) -> Result { let saved_conn = find_connection_by_id(&app, &connection_id)?; @@ -2800,8 +2792,7 @@ pub async fn fetch_blob_as_data_url( ¶ms, &table, &col_name, - &pk_col, - pk_val, + &pk_map, schema.as_deref(), ) .await?; diff --git a/src-tauri/src/drivers/driver_trait.rs b/src-tauri/src/drivers/driver_trait.rs index f92245a3..e66e4f6d 100644 --- a/src-tauri/src/drivers/driver_trait.rs +++ b/src-tauri/src/drivers/driver_trait.rs @@ -435,8 +435,7 @@ pub trait DatabaseDriver: Send + Sync { &self, params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, col_name: &str, new_val: serde_json::Value, schema: Option<&str>, @@ -447,8 +446,7 @@ pub trait DatabaseDriver: Send + Sync { &self, params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, schema: Option<&str>, ) -> Result; @@ -459,8 +457,7 @@ pub trait DatabaseDriver: Send + Sync { _params: &ConnectionParams, _table: &str, _col_name: &str, - _pk_col: &str, - _pk_val: serde_json::Value, + _pk_map: &std::collections::HashMap, _schema: Option<&str>, _file_path: &str, ) -> Result<(), String> { @@ -472,8 +469,7 @@ pub trait DatabaseDriver: Send + Sync { _params: &ConnectionParams, _table: &str, _col_name: &str, - _pk_col: &str, - _pk_val: serde_json::Value, + _pk_map: &std::collections::HashMap, _schema: Option<&str>, ) -> Result { Err("BLOB preview not supported by this driver".into()) diff --git a/src-tauri/src/drivers/mysql/mod.rs b/src-tauri/src/drivers/mysql/mod.rs index a01e5591..ce673f38 100644 --- a/src-tauri/src/drivers/mysql/mod.rs +++ b/src-tauri/src/drivers/mysql/mod.rs @@ -321,39 +321,39 @@ pub async fn get_indexes( .collect()) } +/// Sort the pk_map into a deterministic (col, val) vec for use with QueryBuilder. +fn build_mysql_pk_where( + pk_map: &HashMap, +) -> Result, String> { + if pk_map.is_empty() { + return Err("pk_map must not be empty".into()); + } + let mut pairs: Vec<(String, serde_json::Value)> = pk_map + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + pairs.sort_by(|a, b| a.0.cmp(&b.0)); + Ok(pairs) +} + + pub async fn save_blob_column_to_file( params: &ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, file_path: &str, ) -> Result<(), String> { - let pool = get_mysql_pool(params).await?; - - let query = format!( - "SELECT `{}` FROM `{}` WHERE `{}` = ?", - col_name, table, pk_col - ); - - let row = match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).fetch_one(&pool).await - } else if n.is_f64() { - sqlx::query(&query).bind(n.as_f64()).fetch_one(&pool).await - } else { - sqlx::query(&query) - .bind(n.to_string()) - .fetch_one(&pool) - .await - } - } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).fetch_one(&pool).await, - _ => return Err("Unsupported PK type".into()), - } - .map_err(|e| e.to_string())?; - + let row = mysql_fetch_one_with_pk( + params, + &format!( + "SELECT `{}` FROM `{}`", + escape_identifier(col_name), + escape_identifier(table) + ), + pk_map, + ) + .await?; let bytes: Vec = row.try_get(0).map_err(|e| e.to_string())?; std::fs::write(file_path, bytes).map_err(|e| e.to_string()) } @@ -362,77 +362,131 @@ pub async fn fetch_blob_column_as_data_url( params: &ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, ) -> Result { - let pool = get_mysql_pool(params).await?; - - let query = format!( - "SELECT `{}` FROM `{}` WHERE `{}` = ?", - col_name, table, pk_col - ); + let row = mysql_fetch_one_with_pk( + params, + &format!( + "SELECT `{}` FROM `{}`", + escape_identifier(col_name), + escape_identifier(table) + ), + pk_map, + ) + .await?; + let bytes: Vec = row.try_get(0).map_err(|e| e.to_string())?; + Ok(crate::drivers::common::encode_blob_full(&bytes)) +} - let row = match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).fetch_one(&pool).await - } else if n.is_f64() { - sqlx::query(&query).bind(n.as_f64()).fetch_one(&pool).await - } else { - sqlx::query(&query) - .bind(n.to_string()) - .fetch_one(&pool) - .await +/// Execute a SELECT query appending a WHERE clause built from pk_map and return the first row. +async fn mysql_fetch_one_with_pk( + params: &ConnectionParams, + select_from: &str, + pk_map: &HashMap, +) -> Result { + let pool = get_mysql_pool(params).await?; + let pairs = build_mysql_pk_where(pk_map)?; + let mut first = true; + let mut qb3 = sqlx::QueryBuilder::::new(format!("{} WHERE ", select_from)); + for (col, val) in &pairs { + if !first { + qb3.push(" AND "); + } + qb3.push(format!("`{}` = ", escape_identifier(col))); + match val { + serde_json::Value::Number(n) => { + if n.is_i64() { + qb3.push_bind(n.as_i64()); + } else if n.is_f64() { + qb3.push_bind(n.as_f64()); + } else { + qb3.push_bind(n.to_string()); + } } + serde_json::Value::String(s) => { + if let Some(n) = parse_unsafe_bigint_string(s) { + qb3.push_bind(n); + } else { + qb3.push_bind(s.clone()); + } + } + _ => return Err("Unsupported PK type".into()), } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).fetch_one(&pool).await, - _ => return Err("Unsupported PK type".into()), + first = false; } - .map_err(|e| e.to_string())?; - - let bytes: Vec = row.try_get(0).map_err(|e| e.to_string())?; - Ok(crate::drivers::common::encode_blob_full(&bytes)) + qb3.build().fetch_one(&pool).await.map_err(|e| e.to_string()) } -pub async fn delete_record( +/// Execute a DELETE/UPDATE query appending a WHERE clause from pk_map. +/// Returns the number of affected rows. +async fn mysql_execute_with_pk( params: &ConnectionParams, - table: &str, - pk_col: &str, - pk_val: serde_json::Value, + prefix: &str, + pk_map: &HashMap, ) -> Result { let pool = get_mysql_pool(params).await?; - - let query = format!("DELETE FROM `{}` WHERE `{}` = ?", table, pk_col); - - let result = match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).execute(&pool).await - } else if n.is_f64() { - sqlx::query(&query).bind(n.as_f64()).execute(&pool).await - } else { - sqlx::query(&query).bind(n.to_string()).execute(&pool).await + let pairs = build_mysql_pk_where(pk_map)?; + let mut qb = sqlx::QueryBuilder::::new(format!("{} WHERE ", prefix)); + let mut first = true; + for (col, val) in &pairs { + if !first { + qb.push(" AND "); + } + qb.push(format!("`{}` = ", escape_identifier(col))); + match val { + serde_json::Value::Number(n) => { + if n.is_i64() { + qb.push_bind(n.as_i64()); + } else if n.is_f64() { + qb.push_bind(n.as_f64()); + } else { + qb.push_bind(n.to_string()); + } } + serde_json::Value::String(s) => { + if let Some(n) = parse_unsafe_bigint_string(s) { + qb.push_bind(n); + } else { + qb.push_bind(s.clone()); + } + } + _ => return Err("Unsupported PK type".into()), } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).execute(&pool).await, - _ => return Err("Unsupported PK type".into()), - }; + first = false; + } + let result = qb.build().execute(&pool).await.map_err(|e| e.to_string())?; + Ok(result.rows_affected()) +} - result.map(|r| r.rows_affected()).map_err(|e| e.to_string()) +pub async fn delete_record( + params: &ConnectionParams, + table: &str, + pk_map: &HashMap, +) -> Result { + mysql_execute_with_pk( + params, + &format!("DELETE FROM `{}`", escape_identifier(table)), + pk_map, + ) + .await } pub async fn update_record( params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, col_name: &str, new_val: serde_json::Value, max_blob_size: u64, ) -> Result { let pool = get_mysql_pool(params).await?; + let pk_pairs = build_mysql_pk_where(pk_map)?; - let mut qb = sqlx::QueryBuilder::new(format!("UPDATE `{}` SET `{}` = ", table, col_name)); + let mut qb = sqlx::QueryBuilder::new(format!( + "UPDATE `{}` SET `{}` = ", + escape_identifier(table), + escape_identifier(col_name) + )); match new_val { serde_json::Value::Number(n) => { @@ -443,28 +497,19 @@ pub async fn update_record( } } serde_json::Value::String(s) => { - // Check for special sentinel value to use DEFAULT if s == "__USE_DEFAULT__" { qb.push("DEFAULT"); } else if let Some(bytes) = crate::drivers::common::decode_blob_wire_format(&s, max_blob_size) { - // Blob wire format: decode to raw bytes so the DB stores binary data, - // not the internal wire format string. qb.push_bind(bytes); } else if is_raw_sql_function(&s) { - // If it's a raw SQL function (e.g., ST_GeomFromText('POINT(1 2)', 4326)) - // insert it directly without parameter binding qb.push(s); } else if is_wkt_geometry(&s) { - // If it's WKT geometry format, wrap with ST_GeomFromText qb.push("ST_GeomFromText("); qb.push_bind(s); qb.push(")"); } else if let Some(n) = parse_unsafe_bigint_string(&s) { - // Bigints outside JS safe range come back from the UI as strings - // (see drivers::common::i64_to_json). Bind them as native i64 so - // BIGINT columns receive the exact value. qb.push_bind(n); } else { qb.push_bind(s); @@ -484,28 +529,36 @@ pub async fn update_record( } } - qb.push(format!(" WHERE `{}` = ", pk_col)); - - match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - qb.push_bind(n.as_i64()); - } else { - qb.push_bind(n.as_f64()); - } + qb.push(" WHERE "); + let mut first = true; + for (col, val) in &pk_pairs { + if !first { + qb.push(" AND "); } - serde_json::Value::String(s) => { - if let Some(n) = parse_unsafe_bigint_string(&s) { - qb.push_bind(n); - } else { - qb.push_bind(s); + qb.push(format!("`{}` = ", escape_identifier(col))); + match val { + serde_json::Value::Number(n) => { + if n.is_i64() { + qb.push_bind(n.as_i64()); + } else if n.is_f64() { + qb.push_bind(n.as_f64()); + } else { + qb.push_bind(n.to_string()); + } + } + serde_json::Value::String(s) => { + if let Some(n) = parse_unsafe_bigint_string(s) { + qb.push_bind(n); + } else { + qb.push_bind(s.clone()); + } } + _ => return Err("Unsupported PK type".into()), } - _ => return Err("Unsupported PK type".into()), + first = false; } - let query = qb.build(); - let result = query.execute(&pool).await.map_err(|e| e.to_string())?; + let result = qb.build().execute(&pool).await.map_err(|e| e.to_string())?; Ok(result.rows_affected()) } @@ -1571,34 +1624,23 @@ impl DatabaseDriver for MysqlDriver { &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, col_name: &str, new_val: serde_json::Value, _schema: Option<&str>, max_blob_size: u64, ) -> Result { - update_record( - params, - table, - pk_col, - pk_val, - col_name, - new_val, - max_blob_size, - ) - .await + update_record(params, table, pk_map, col_name, new_val, max_blob_size).await } async fn delete_record( &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, ) -> Result { - delete_record(params, table, pk_col, pk_val).await + delete_record(params, table, pk_map).await } async fn save_blob_to_file( @@ -1606,12 +1648,11 @@ impl DatabaseDriver for MysqlDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, file_path: &str, ) -> Result<(), String> { - save_blob_column_to_file(params, table, col_name, pk_col, pk_val, file_path).await + save_blob_column_to_file(params, table, col_name, pk_map, file_path).await } async fn fetch_blob_as_data_url( @@ -1619,11 +1660,10 @@ impl DatabaseDriver for MysqlDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, ) -> Result { - fetch_blob_column_as_data_url(params, table, col_name, pk_col, pk_val).await + fetch_blob_column_as_data_url(params, table, col_name, pk_map).await } async fn get_create_table_sql( diff --git a/src-tauri/src/drivers/mysql/tests.rs b/src-tauri/src/drivers/mysql/tests.rs index e26dc835..384bfb03 100644 --- a/src-tauri/src/drivers/mysql/tests.rs +++ b/src-tauri/src/drivers/mysql/tests.rs @@ -1,3 +1,4 @@ +use super::build_mysql_pk_where; use super::explain::{parse_analyze_actual, parse_mysql_analyze_text, parse_mysql_query_block}; use super::MysqlDriver; use crate::drivers::driver_trait::DatabaseDriver; @@ -601,3 +602,34 @@ fn parse_mysql_analyze_text_reports_total_time_for_looped_node() { "expected ~2646ms total for index lookup, got {total}" ); } + +mod build_mysql_pk_where_tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn single_column_returns_correct_pair() { + let mut pk_map = HashMap::new(); + pk_map.insert("id".to_string(), serde_json::json!(42)); + let pairs = build_mysql_pk_where(&pk_map).unwrap(); + assert_eq!(pairs.len(), 1); + assert_eq!(pairs[0].0, "id"); + assert_eq!(pairs[0].1, serde_json::json!(42)); + } + + #[test] + fn composite_pk_columns_are_sorted_alphabetically() { + let mut pk_map = HashMap::new(); + pk_map.insert("z_col".to_string(), serde_json::json!(1)); + pk_map.insert("a_col".to_string(), serde_json::json!(2)); + let pairs = build_mysql_pk_where(&pk_map).unwrap(); + assert_eq!(pairs[0].0, "a_col"); + assert_eq!(pairs[1].0, "z_col"); + } + + #[test] + fn empty_pk_map_is_rejected() { + let pk_map: HashMap = HashMap::new(); + assert!(build_mysql_pk_where(&pk_map).is_err()); + } +} diff --git a/src-tauri/src/drivers/postgres/binding.rs b/src-tauri/src/drivers/postgres/binding.rs index a3652479..405fb66d 100644 --- a/src-tauri/src/drivers/postgres/binding.rs +++ b/src-tauri/src/drivers/postgres/binding.rs @@ -3,6 +3,7 @@ use super::helpers::{ json_array_to_pg_literal, try_parse_pg_array, }; use crate::drivers::common::parse_unsafe_bigint_string; +use std::collections::HashMap; use tokio_postgres::types::ToSql; pub(super) type PgParam = Box; @@ -55,6 +56,29 @@ pub(super) fn build_pk_predicate( } } +/// Build a compound WHERE predicate from all entries of a pk_map. +/// Keys are sorted for determinism. Returns the predicate string and all boxed params. +/// E.g. `"col1" = $2 AND "col2" = $3` with params starting at placeholder_idx. +pub(super) fn build_pk_map_predicate( + pk_map: &HashMap, + placeholder_idx: usize, +) -> Result<(String, Vec), String> { + if pk_map.is_empty() { + return Err("pk_map must not be empty".into()); + } + let mut keys: Vec<&String> = pk_map.keys().collect(); + keys.sort(); + let mut predicates = Vec::new(); + let mut params: Vec = Vec::new(); + for key in keys { + let val = pk_map[key].clone(); + let (pred, param) = build_pk_predicate(key, val, placeholder_idx + params.len())?; + predicates.push(pred); + params.push(param); + } + Ok((predicates.join(" AND "), params)) +} + pub(super) fn bind_pg_value( value: serde_json::Value, placeholder_idx: usize, diff --git a/src-tauri/src/drivers/postgres/mod.rs b/src-tauri/src/drivers/postgres/mod.rs index 50ba7480..f3f73f73 100644 --- a/src-tauri/src/drivers/postgres/mod.rs +++ b/src-tauri/src/drivers/postgres/mod.rs @@ -16,7 +16,7 @@ use crate::models::{ TableColumn, TableInfo, TriggerInfo, ViewInfo, }; use crate::pool_manager::get_postgres_pool; -use binding::{PgValueOptions, bind_pg_value, build_pk_predicate}; +use binding::{PgValueOptions, bind_pg_value, build_pk_map_predicate}; use client::{execute, format_pg_error, get_client, query_all, query_one}; pub use explain::explain_query; use extract::extract_value; @@ -415,14 +415,13 @@ pub async fn save_blob_column_to_file( params: &ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, schema: &str, file_path: &str, ) -> Result<(), String> { let pool = get_postgres_pool(params).await?; - let (predicate, param) = build_pk_predicate(pk_col, pk_val, 1)?; + let (predicate, pk_params) = build_pk_map_predicate(pk_map, 1)?; let query = format!( "SELECT \"{}\" FROM \"{}\".\"{}\" WHERE {}", escape_identifier(col_name), @@ -431,7 +430,11 @@ pub async fn save_blob_column_to_file( predicate, ); - let row = query_one(&pool, &query, &[param.as_ref() as &(dyn ToSql + Sync)]).await?; + let params_ref: Vec<&(dyn ToSql + Sync)> = pk_params + .iter() + .map(|b| b.as_ref() as &(dyn ToSql + Sync)) + .collect(); + let row = query_one(&pool, &query, ¶ms_ref).await?; let bytes: Vec = row.try_get(0).map_err(|e| format_pg_error(&e))?; std::fs::write(file_path, bytes).map_err(|e| e.to_string()) @@ -441,13 +444,12 @@ pub async fn fetch_blob_column_as_data_url( params: &ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, schema: &str, ) -> Result { let pool = get_postgres_pool(params).await?; - let (predicate, param) = build_pk_predicate(pk_col, pk_val, 1)?; + let (predicate, pk_params) = build_pk_map_predicate(pk_map, 1)?; let query = format!( "SELECT \"{}\" FROM \"{}\".\"{}\" WHERE {}", escape_identifier(col_name), @@ -456,7 +458,11 @@ pub async fn fetch_blob_column_as_data_url( predicate, ); - let row = query_one(&pool, &query, &[param.as_ref() as &(dyn ToSql + Sync)]).await?; + let params_ref: Vec<&(dyn ToSql + Sync)> = pk_params + .iter() + .map(|b| b.as_ref() as &(dyn ToSql + Sync)) + .collect(); + let row = query_one(&pool, &query, ¶ms_ref).await?; let bytes: Vec = row.try_get(0).map_err(|e| format_pg_error(&e))?; Ok(crate::drivers::common::encode_blob_full(&bytes)) @@ -541,13 +547,12 @@ fn update_record_error_context( pub async fn delete_record( params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, schema: &str, ) -> Result { let pool = get_postgres_pool(params).await?; - let (predicate, param) = build_pk_predicate(pk_col, pk_val, 1)?; + let (predicate, pk_params) = build_pk_map_predicate(pk_map, 1)?; let query = format!( "DELETE FROM \"{}\".\"{}\" WHERE {}", escape_identifier(schema), @@ -555,14 +560,17 @@ pub async fn delete_record( predicate, ); - execute(&pool, &query, &[param.as_ref() as &(dyn ToSql + Sync)]).await + let params_ref: Vec<&(dyn ToSql + Sync)> = pk_params + .iter() + .map(|b| b.as_ref() as &(dyn ToSql + Sync)) + .collect(); + execute(&pool, &query, ¶ms_ref).await } pub async fn update_record( params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, col_name: &str, new_val: serde_json::Value, schema: &str, @@ -583,7 +591,7 @@ pub async fn update_record( } }; let new_val_for_context = new_val.clone(); - let pk_val_for_context = pk_val.clone(); + let pk_map_for_context = pk_map.clone(); let mut query = format!( "UPDATE \"{}\".\"{}\" SET \"{}\" = ", @@ -592,11 +600,11 @@ pub async fn update_record( escape_identifier(col_name) ); - let mut params: Vec> = Vec::new(); + let mut bound_params: Vec> = Vec::new(); let bound = bind_pg_value( new_val, - params.len() + 1, + bound_params.len() + 1, PgValueOptions { column_type: column_data_type.as_deref(), max_blob_size, @@ -605,26 +613,36 @@ pub async fn update_record( )?; query.push_str(&bound.sql); if let Some(param) = bound.param { - params.push(param); + bound_params.push(param); } - let (predicate, pk_param) = build_pk_predicate(pk_col, pk_val, params.len() + 1)?; + let (predicate, pk_params) = build_pk_map_predicate(pk_map, bound_params.len() + 1)?; query.push_str(" WHERE "); query.push_str(&predicate); - params.push(pk_param); + bound_params.extend(pk_params); - let params: Vec<&(dyn ToSql + Sync)> = params + let params_ref: Vec<&(dyn ToSql + Sync)> = bound_params .iter() .map(|b| b.as_ref() as &(dyn ToSql + Sync)) .collect(); - execute(&pool, &query, ¶ms).await.map_err(|err| { + let first_pk_col = { + let mut keys: Vec<&String> = pk_map_for_context.keys().collect(); + keys.sort(); + keys.first().map(|k| k.as_str()).unwrap_or("") + }; + let first_pk_val = pk_map_for_context + .get(first_pk_col) + .cloned() + .unwrap_or(serde_json::Value::Null); + + execute(&pool, &query, ¶ms_ref).await.map_err(|err| { update_record_error_context( err, schema, table, - pk_col, - &pk_val_for_context, + first_pk_col, + &first_pk_val, col_name, &new_val_for_context, column_data_type.as_deref(), @@ -1670,8 +1688,7 @@ impl DatabaseDriver for PostgresDriver { &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, col_name: &str, new_val: serde_json::Value, schema: Option<&str>, @@ -1680,8 +1697,7 @@ impl DatabaseDriver for PostgresDriver { update_record( params, table, - pk_col, - pk_val, + pk_map, col_name, new_val, self.resolve_schema(schema), @@ -1694,11 +1710,10 @@ impl DatabaseDriver for PostgresDriver { &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, schema: Option<&str>, ) -> Result { - delete_record(params, table, pk_col, pk_val, self.resolve_schema(schema)).await + delete_record(params, table, pk_map, self.resolve_schema(schema)).await } async fn save_blob_to_file( @@ -1706,8 +1721,7 @@ impl DatabaseDriver for PostgresDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, schema: Option<&str>, file_path: &str, ) -> Result<(), String> { @@ -1715,8 +1729,7 @@ impl DatabaseDriver for PostgresDriver { params, table, col_name, - pk_col, - pk_val, + pk_map, self.resolve_schema(schema), file_path, ) @@ -1728,16 +1741,14 @@ impl DatabaseDriver for PostgresDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, schema: Option<&str>, ) -> Result { fetch_blob_column_as_data_url( params, table, col_name, - pk_col, - pk_val, + pk_map, self.resolve_schema(schema), ) .await diff --git a/src-tauri/src/drivers/postgres/tests.rs b/src-tauri/src/drivers/postgres/tests.rs index c0635746..d94733f3 100644 --- a/src-tauri/src/drivers/postgres/tests.rs +++ b/src-tauri/src/drivers/postgres/tests.rs @@ -1,6 +1,6 @@ use super::binding::{ PgValueOptions, bind_pg_boolean_string, bind_pg_number, bind_pg_numeric_string, bind_pg_value, - build_pk_predicate, + build_pk_map_predicate, build_pk_predicate, }; use super::helpers::{extract_base_type, is_implicit_cast_compatible}; @@ -536,3 +536,51 @@ mod build_pk_predicate_tests { assert!(build_pk_predicate("id", serde_json::json!(true), 1).is_err()); } } + +mod build_pk_map_predicate_tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn single_integer_column() { + let mut pk_map = HashMap::new(); + pk_map.insert("id".to_string(), serde_json::json!(1)); + let (sql, params) = build_pk_map_predicate(&pk_map, 1).unwrap(); + assert_eq!(sql, "\"id\" = CAST($1 AS bigint)"); + assert_eq!(params.len(), 1); + } + + #[test] + fn composite_pk_sorted_alphabetically_with_consecutive_placeholders() { + let mut pk_map = HashMap::new(); + pk_map.insert("z_col".to_string(), serde_json::json!("alice")); + pk_map.insert("a_col".to_string(), serde_json::json!("bob")); + let (sql, params) = build_pk_map_predicate(&pk_map, 1).unwrap(); + assert_eq!(sql, "\"a_col\" = $1 AND \"z_col\" = $2"); + assert_eq!(params.len(), 2); + } + + #[test] + fn non_one_starting_placeholder_idx() { + let mut pk_map = HashMap::new(); + pk_map.insert("id".to_string(), serde_json::json!(5)); + let (sql, _) = build_pk_map_predicate(&pk_map, 3).unwrap(); + assert_eq!(sql, "\"id\" = CAST($3 AS bigint)"); + } + + #[test] + fn composite_pk_with_mixed_types() { + let mut pk_map = HashMap::new(); + pk_map.insert("b_col".to_string(), serde_json::json!("alice")); + pk_map.insert("a_col".to_string(), serde_json::json!(99)); + let (sql, params) = build_pk_map_predicate(&pk_map, 1).unwrap(); + assert_eq!(sql, "\"a_col\" = CAST($1 AS bigint) AND \"b_col\" = $2"); + assert_eq!(params.len(), 2); + } + + #[test] + fn empty_pk_map_is_rejected() { + let pk_map: HashMap = HashMap::new(); + assert!(build_pk_map_predicate(&pk_map, 1).is_err()); + } +} diff --git a/src-tauri/src/drivers/sqlite/mod.rs b/src-tauri/src/drivers/sqlite/mod.rs index 2d0d405c..c97c3703 100644 --- a/src-tauri/src/drivers/sqlite/mod.rs +++ b/src-tauri/src/drivers/sqlite/mod.rs @@ -267,34 +267,70 @@ pub async fn get_indexes( Ok(result) } -pub async fn save_blob_column_to_file( - params: &ConnectionParams, - table: &str, - col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, - file_path: &str, +fn sqlite_push_pk_val( + qb: &mut sqlx::QueryBuilder, + val: &serde_json::Value, ) -> Result<(), String> { - let pool = get_sqlite_pool(params).await?; - - let query = format!( - "SELECT \"{}\" FROM \"{}\" WHERE \"{}\" = ?", - col_name, table, pk_col - ); - - let row = match pk_val { + match val { serde_json::Value::Number(n) => { if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).fetch_one(&pool).await + qb.push_bind(n.as_i64()); + } else { + qb.push_bind(n.as_f64()); + } + } + serde_json::Value::String(s) => { + if let Some(n) = parse_unsafe_bigint_string(s) { + qb.push_bind(n); } else { - sqlx::query(&query).bind(n.as_f64()).fetch_one(&pool).await + qb.push_bind(s.clone()); } } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).fetch_one(&pool).await, _ => return Err("Unsupported PK type".into()), } - .map_err(|e| e.to_string())?; + Ok(()) +} +fn sqlite_push_pk_where( + qb: &mut sqlx::QueryBuilder, + pk_map: &HashMap, +) -> Result<(), String> { + if pk_map.is_empty() { + return Err("pk_map must not be empty".into()); + } + let mut pairs: Vec<(&String, &serde_json::Value)> = pk_map.iter().collect(); + pairs.sort_by_key(|(k, _)| k.as_str()); + let mut first = true; + for (col, val) in &pairs { + if !first { + qb.push(" AND "); + } + qb.push(format!("\"{}\" = ", escape_identifier(col))); + sqlite_push_pk_val(qb, val)?; + first = false; + } + Ok(()) +} + +pub async fn save_blob_column_to_file( + params: &ConnectionParams, + table: &str, + col_name: &str, + pk_map: &HashMap, + file_path: &str, +) -> Result<(), String> { + let pool = get_sqlite_pool(params).await?; + let mut qb = sqlx::QueryBuilder::::new(format!( + "SELECT \"{}\" FROM \"{}\" WHERE ", + escape_identifier(col_name), + escape_identifier(table) + )); + sqlite_push_pk_where(&mut qb, pk_map)?; + let row = qb + .build() + .fetch_one(&pool) + .await + .map_err(|e| e.to_string())?; let bytes: Vec = row.try_get(0).map_err(|e| e.to_string())?; std::fs::write(file_path, bytes).map_err(|e| e.to_string()) } @@ -303,29 +339,20 @@ pub async fn fetch_blob_column_as_data_url( params: &ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, ) -> Result { let pool = get_sqlite_pool(params).await?; - - let query = format!( - "SELECT \"{}\" FROM \"{}\" WHERE \"{}\" = ?", - col_name, table, pk_col - ); - - let row = match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).fetch_one(&pool).await - } else { - sqlx::query(&query).bind(n.as_f64()).fetch_one(&pool).await - } - } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).fetch_one(&pool).await, - _ => return Err("Unsupported PK type".into()), - } - .map_err(|e| e.to_string())?; - + let mut qb = sqlx::QueryBuilder::::new(format!( + "SELECT \"{}\" FROM \"{}\" WHERE ", + escape_identifier(col_name), + escape_identifier(table) + )); + sqlite_push_pk_where(&mut qb, pk_map)?; + let row = qb + .build() + .fetch_one(&pool) + .await + .map_err(|e| e.to_string())?; let bytes: Vec = row.try_get(0).map_err(|e| e.to_string())?; Ok(crate::drivers::common::encode_blob_full(&bytes)) } @@ -333,40 +360,37 @@ pub async fn fetch_blob_column_as_data_url( pub async fn delete_record( params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, ) -> Result { let pool = get_sqlite_pool(params).await?; - - let query = format!("DELETE FROM \"{}\" WHERE \"{}\" = ?", table, pk_col); - - let result = match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - sqlx::query(&query).bind(n.as_i64()).execute(&pool).await - } else { - sqlx::query(&query).bind(n.as_f64()).execute(&pool).await - } - } - serde_json::Value::String(s) => sqlx::query(&query).bind(s).execute(&pool).await, - _ => return Err("Unsupported PK type".into()), - }; - - result.map(|r| r.rows_affected()).map_err(|e| e.to_string()) + let mut qb = sqlx::QueryBuilder::::new(format!( + "DELETE FROM \"{}\" WHERE ", + escape_identifier(table) + )); + sqlite_push_pk_where(&mut qb, pk_map)?; + let result = qb + .build() + .execute(&pool) + .await + .map_err(|e| e.to_string())?; + Ok(result.rows_affected()) } pub async fn update_record( params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &HashMap, col_name: &str, new_val: serde_json::Value, max_blob_size: u64, ) -> Result { let pool = get_sqlite_pool(params).await?; - let mut qb = sqlx::QueryBuilder::new(format!("UPDATE \"{}\" SET \"{}\" = ", table, col_name)); + let mut qb = sqlx::QueryBuilder::new(format!( + "UPDATE \"{}\" SET \"{}\" = ", + escape_identifier(table), + escape_identifier(col_name) + )); match new_val { serde_json::Value::Number(n) => { @@ -377,18 +401,13 @@ pub async fn update_record( } } serde_json::Value::String(s) => { - // Check for special sentinel value to use DEFAULT if s == "__USE_DEFAULT__" { qb.push("DEFAULT"); } else if let Some(bytes) = crate::drivers::common::decode_blob_wire_format(&s, max_blob_size) { - // Blob wire format: decode to raw bytes so the DB stores binary data, - // not the internal wire format string. qb.push_bind(bytes); } else if let Some(n) = parse_unsafe_bigint_string(&s) { - // Bigints outside JS safe range come back from the UI as strings - // (see drivers::common::i64_to_json). Bind them as native i64. qb.push_bind(n); } else { qb.push_bind(s); @@ -403,28 +422,10 @@ pub async fn update_record( _ => return Err("Unsupported Value type".into()), } - qb.push(format!(" WHERE \"{}\" = ", pk_col)); + qb.push(" WHERE "); + sqlite_push_pk_where(&mut qb, pk_map)?; - match pk_val { - serde_json::Value::Number(n) => { - if n.is_i64() { - qb.push_bind(n.as_i64()); - } else { - qb.push_bind(n.as_f64()); - } - } - serde_json::Value::String(s) => { - if let Some(n) = parse_unsafe_bigint_string(&s) { - qb.push_bind(n); - } else { - qb.push_bind(s); - } - } - _ => return Err("Unsupported PK type".into()), - } - - let query = qb.build(); - let result = query.execute(&pool).await.map_err(|e| e.to_string())?; + let result = qb.build().execute(&pool).await.map_err(|e| e.to_string())?; Ok(result.rows_affected()) } @@ -1177,34 +1178,23 @@ impl DatabaseDriver for SqliteDriver { &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, col_name: &str, new_val: serde_json::Value, _schema: Option<&str>, max_blob_size: u64, ) -> Result { - update_record( - params, - table, - pk_col, - pk_val, - col_name, - new_val, - max_blob_size, - ) - .await + update_record(params, table, pk_map, col_name, new_val, max_blob_size).await } async fn delete_record( &self, params: &crate::models::ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, ) -> Result { - delete_record(params, table, pk_col, pk_val).await + delete_record(params, table, pk_map).await } async fn save_blob_to_file( @@ -1212,12 +1202,11 @@ impl DatabaseDriver for SqliteDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, file_path: &str, ) -> Result<(), String> { - save_blob_column_to_file(params, table, col_name, pk_col, pk_val, file_path).await + save_blob_column_to_file(params, table, col_name, pk_map, file_path).await } async fn fetch_blob_as_data_url( @@ -1225,11 +1214,10 @@ impl DatabaseDriver for SqliteDriver { params: &crate::models::ConnectionParams, table: &str, col_name: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, _schema: Option<&str>, ) -> Result { - fetch_blob_column_as_data_url(params, table, col_name, pk_col, pk_val).await + fetch_blob_column_as_data_url(params, table, col_name, pk_map).await } async fn get_create_table_sql( diff --git a/src-tauri/src/drivers/sqlite/tests.rs b/src-tauri/src/drivers/sqlite/tests.rs index 36c8e430..b5230cf8 100644 --- a/src-tauri/src/drivers/sqlite/tests.rs +++ b/src-tauri/src/drivers/sqlite/tests.rs @@ -1,4 +1,5 @@ use super::explain::{build_sqlite_tree, parse_sqlite_detail}; +use super::sqlite_push_pk_where; use super::{alter_view, create_view, drop_view, get_view_columns, get_view_definition, get_views}; use crate::models::{ConnectionParams, DatabaseSelection}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; @@ -150,3 +151,43 @@ async fn test_view_lifecycle() { // Cleanup: Close the pool created by the functions (via pool_manager) crate::pool_manager::close_pool(¶ms).await; } + +mod sqlite_push_pk_where_tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn single_column_generates_correct_predicate() { + let mut pk_map = HashMap::new(); + pk_map.insert("id".to_string(), serde_json::json!(42)); + let mut qb = sqlx::QueryBuilder::::new(""); + sqlite_push_pk_where(&mut qb, &pk_map).unwrap(); + assert_eq!(qb.sql(), "\"id\" = ?"); + } + + #[test] + fn composite_pk_columns_are_sorted_alphabetically() { + let mut pk_map = HashMap::new(); + pk_map.insert("z_col".to_string(), serde_json::json!(1)); + pk_map.insert("a_col".to_string(), serde_json::json!(2)); + let mut qb = sqlx::QueryBuilder::::new(""); + sqlite_push_pk_where(&mut qb, &pk_map).unwrap(); + assert_eq!(qb.sql(), "\"a_col\" = ? AND \"z_col\" = ?"); + } + + #[test] + fn empty_pk_map_is_rejected() { + let pk_map: HashMap = HashMap::new(); + let mut qb = sqlx::QueryBuilder::::new(""); + assert!(sqlite_push_pk_where(&mut qb, &pk_map).is_err()); + } + + #[test] + fn double_quote_in_column_name_is_escaped() { + let mut pk_map = HashMap::new(); + pk_map.insert("a\"b".to_string(), serde_json::json!(1)); + let mut qb = sqlx::QueryBuilder::::new(""); + sqlite_push_pk_where(&mut qb, &pk_map).unwrap(); + assert_eq!(qb.sql(), "\"a\"\"b\" = ?"); + } +} diff --git a/src-tauri/src/plugins/driver.rs b/src-tauri/src/plugins/driver.rs index 038ef963..5a7dac37 100644 --- a/src-tauri/src/plugins/driver.rs +++ b/src-tauri/src/plugins/driver.rs @@ -545,14 +545,13 @@ impl DatabaseDriver for RpcDriver { &self, params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, col_name: &str, new_val: serde_json::Value, schema: Option<&str>, max_blob_size: u64, ) -> Result { - let res = self.process.call("update_record", json!({ "params": params, "table": table, "pk_col": pk_col, "pk_val": pk_val, "col_name": col_name, "new_val": new_val, "schema": schema, "max_blob_size": max_blob_size })).await?; + let res = self.process.call("update_record", json!({ "params": params, "table": table, "pk_map": pk_map, "col_name": col_name, "new_val": new_val, "schema": schema, "max_blob_size": max_blob_size })).await?; serde_json::from_value(res).map_err(|e| e.to_string()) } @@ -560,11 +559,10 @@ impl DatabaseDriver for RpcDriver { &self, params: &ConnectionParams, table: &str, - pk_col: &str, - pk_val: serde_json::Value, + pk_map: &std::collections::HashMap, schema: Option<&str>, ) -> Result { - let res = self.process.call("delete_record", json!({ "params": params, "table": table, "pk_col": pk_col, "pk_val": pk_val, "schema": schema })).await?; + let res = self.process.call("delete_record", json!({ "params": params, "table": table, "pk_map": pk_map, "schema": schema })).await?; serde_json::from_value(res).map_err(|e| e.to_string()) } diff --git a/src/components/notebook/SqlCellResult.tsx b/src/components/notebook/SqlCellResult.tsx index 23e56163..78939476 100644 --- a/src/components/notebook/SqlCellResult.tsx +++ b/src/components/notebook/SqlCellResult.tsx @@ -79,7 +79,7 @@ export function SqlCellResult({ columns={result.columns} data={result.rows} tableName={null} - pkColumn={null} + pkColumns={null} readonly /> diff --git a/src/components/ui/BlobInput.tsx b/src/components/ui/BlobInput.tsx index 93c4f455..4f1a4fea 100644 --- a/src/components/ui/BlobInput.tsx +++ b/src/components/ui/BlobInput.tsx @@ -30,8 +30,7 @@ export interface BlobInputProps { className?: string; connectionId?: string | null; tableName?: string | null; - pkCol?: string | null; - pkVal?: unknown; + pkMap?: Record | null; colName?: string | null; schema?: string | null; } @@ -50,8 +49,7 @@ export const BlobInput = ({ className = "", connectionId, tableName, - pkCol, - pkVal, + pkMap, colName, schema, }: BlobInputProps) => { @@ -69,9 +67,8 @@ export const BlobInput = ({ metadata?.isTruncated && connectionId && tableName && - pkCol && - pkVal !== null && - pkVal !== undefined && + pkMap && + Object.keys(pkMap).length > 0 && colName; // Build a data URL for image preview from the BLOB wire format (non-file-ref, non-truncated) @@ -123,8 +120,7 @@ export const BlobInput = ({ connectionId, table: tableName, colName, - pkCol, - pkVal, + pkMap, ...(schema ? { schema } : {}), }) .then((dataUrl) => { @@ -136,7 +132,7 @@ export const BlobInput = ({ return () => { cancelled = true; }; - }, [isImage, canFetchFull, connectionId, tableName, colName, pkCol, pkVal, schema]); + }, [isImage, canFetchFull, connectionId, tableName, colName, pkMap, schema]); const effectiveImageDataUrl = imageDataUrl ?? fileRefPreviewUrl ?? dbPreviewUrl; @@ -192,8 +188,7 @@ export const BlobInput = ({ connectionId, table: tableName, colName, - pkCol, - pkVal, + pkMap, filePath, ...(schema ? { schema } : {}), }); diff --git a/src/components/ui/DataGrid.tsx b/src/components/ui/DataGrid.tsx index a5bdd9e6..9da4721f 100644 --- a/src/components/ui/DataGrid.tsx +++ b/src/components/ui/DataGrid.tsx @@ -41,6 +41,8 @@ import { getColumnSortState, calculateSelectionRange, toggleSetValue, + buildPkMap, + serializePkKey, type MergedRow, } from "../../utils/dataGrid"; import { isGeometricType, formatGeometricValue } from "../../utils/geometry"; @@ -75,7 +77,7 @@ interface DataGridProps { columns: string[]; data: unknown[][]; tableName?: string | null; - pkColumn?: string | null; + pkColumns?: string[] | null; autoIncrementColumns?: string[]; defaultValueColumns?: string[]; nullableColumns?: string[]; @@ -117,7 +119,7 @@ export const DataGrid = React.memo( columns, data, tableName, - pkColumn, + pkColumns, autoIncrementColumns, defaultValueColumns, nullableColumns, @@ -234,12 +236,15 @@ export const DataGrid = React.memo( [onSelectionChange], ); - // Pre-calculate pkIndex once for O(1) lookup instead of O(n) in render loop - const pkIndexMap = useMemo(() => { - if (!pkColumn) return null; - const pkIndex = columns.indexOf(pkColumn); - return pkIndex >= 0 ? pkIndex : null; - }, [columns, pkColumn]); + // Pre-calculate pkIndex array once for O(1) lookup instead of O(n) in render loop + const pkIndexMaps = useMemo((): number[] => { + if (!pkColumns || pkColumns.length === 0) return []; + const indices = pkColumns.map((col) => columns.indexOf(col)); + // If any PK column is absent from the result set, disable editing entirely + // to avoid partial WHERE clauses that could match multiple rows. + if (indices.some((idx) => idx < 0)) return []; + return indices; + }, [columns, pkColumns]); // Create column type map for O(1) lookup during cell rendering const columnTypeMap = useMemo(() => { @@ -269,15 +274,15 @@ export const DataGrid = React.memo( const buildRowLabel = useCallback( (rowData: unknown[], rowIndex: number, isInsertion: boolean): string => { if (isInsertion) return t("dataGrid.newRow", { defaultValue: "NEW" }); - if (pkColumn && pkIndexMap !== null) { - const pkVal = rowData[pkIndexMap]; + if (pkColumns && pkColumns.length > 0 && pkIndexMaps.length > 0) { + const pkVal = rowData[pkIndexMaps[0]]; if (pkVal !== null && pkVal !== undefined && pkVal !== "") { - return `${pkColumn}=${String(pkVal)}`; + return `${pkColumns[0]}=${String(pkVal)}`; } } return `Row ${rowIndex + 1}`; }, - [pkColumn, pkIndexMap, t], + [pkColumns, pkIndexMaps, t], ); const openJsonViewerWindow = useCallback( @@ -296,13 +301,14 @@ export const DataGrid = React.memo( let cellKey: string | null = null; const canSaveBack = (isInsertion && !!tempId) || - (!isInsertion && pkIndexMap !== null); + (!isInsertion && pkIndexMaps.length > 0); if (isInsertion && tempId) { cellKey = `ins:${tempId}:${colName}`; - } else if (!isInsertion && pkIndexMap !== null) { - const pkVal = rowData[pkIndexMap]; - if (pkVal !== null && pkVal !== undefined && pkVal !== "") { - cellKey = `pk:${String(pkVal)}:${colName}`; + } else if (!isInsertion && pkIndexMaps.length > 0) { + const pkMapVal = buildPkMap(pkColumns!, rowData, pkIndexMaps); + const serialized = serializePkKey(pkMapVal); + if (serialized !== "" && serialized !== "null" && serialized !== "undefined") { + cellKey = `pk:${serialized}:${colName}`; } } const sessionId = await invoke("open_json_viewer_window", { @@ -323,7 +329,7 @@ export const DataGrid = React.memo( console.error("Failed to open JSON viewer window:", e); } }, - [buildRowLabel, pkIndexMap], + [buildRowLabel, pkIndexMaps, pkColumns], ); useEffect(() => { @@ -338,16 +344,16 @@ export const DataGrid = React.memo( const { colName, rowData, isInsertion, tempId } = session; if (isInsertion && onPendingInsertionChange && tempId) { onPendingInsertionChange(tempId, colName, value); - } else if (!isInsertion && onPendingChange && pkIndexMap !== null) { - const pkVal = rowData[pkIndexMap]; - onPendingChange(pkVal, colName, value); + } else if (!isInsertion && onPendingChange && pkIndexMaps.length > 0) { + const pkMapVal = buildPkMap(pkColumns!, rowData, pkIndexMaps); + onPendingChange(pkMapVal, colName, value); } }, ); return () => { unlistenPromise.then((fn) => fn()); }; - }, [onPendingChange, onPendingInsertionChange, pkIndexMap]); + }, [onPendingChange, onPendingInsertionChange, pkIndexMaps, pkColumns]); const fksByColumn = useMemo( () => pickPrimaryForeignKeyByColumn(foreignKeys), @@ -445,14 +451,14 @@ export const DataGrid = React.memo( columns.forEach((col, idx) => { rowData[col] = rowArray[idx]; }); - if (!isInsertion && pkIndexMap !== null) { - const pkVal = rowArray[pkIndexMap]; - const pending = pendingChanges?.[String(pkVal)]?.changes; + if (!isInsertion && pkIndexMaps.length > 0) { + const pkMapVal = buildPkMap(pkColumns!, rowArray, pkIndexMaps); + const pending = pendingChanges?.[serializePkKey(pkMapVal)]?.changes; if (pending) Object.assign(rowData, pending); } return rowData; }, - [columns, pkIndexMap, pendingChanges], + [columns, pkIndexMaps, pkColumns, pendingChanges], ); const handleCellDoubleClick = useCallback( @@ -461,7 +467,7 @@ export const DataGrid = React.memo( const mergedRow = mergedRows[rowIndex]; if (!mergedRow) return; - if (mergedRow.type !== "insertion" && !pkColumn) return; + if (mergedRow.type !== "insertion" && pkIndexMaps.length === 0) return; const colName = columns[colIndex]; const colType = columnTypeMap?.get(colName); @@ -514,7 +520,7 @@ export const DataGrid = React.memo( tableName, readonlyProp, mergedRows, - pkColumn, + pkColumns, columns, columnTypeMap, columnLengthMap, @@ -579,17 +585,17 @@ export const DataGrid = React.memo( return; } - // PK Value - check pkIndexMap is valid - if (pkIndexMap === null) { + // PK Value - check pkIndexMaps is valid + if (pkIndexMaps.length === 0 || !pkColumns) { setEditingCell(null); return; } - const pkVal = row[pkIndexMap]; + const pkMapVal = buildPkMap(pkColumns, row, pkIndexMaps); const colName = columns[colIndex]; if (onPendingChange) { // If value matches original, pass undefined to remove the pending change - onPendingChange(pkVal, colName, isUnchanged ? undefined : value); + onPendingChange(pkMapVal, colName, isUnchanged ? undefined : value); setEditingCell(null); return; } @@ -601,8 +607,7 @@ export const DataGrid = React.memo( await invoke("update_record", { connectionId, table: tableName, - pkCol: pkColumn, - pkVal, + pkMap: pkMapVal, colName, newVal: value, ...(activeSchema ? { schema: activeSchema } : {}), @@ -625,8 +630,8 @@ export const DataGrid = React.memo( columns, onPendingInsertionChange, onPendingChange, - pkIndexMap, - pkColumn, + pkIndexMaps, + pkColumns, connectionId, activeSchema, onRefresh, @@ -834,16 +839,16 @@ export const DataGrid = React.memo( return; } - // For existing rows, need pkColumn - if (!pkColumn || pkIndexMap === null) return; + // For existing rows, need pkColumns + if (!pkColumns || pkIndexMaps.length === 0) return; - const pkVal = contextMenu.row[pkIndexMap]; - const pkValStr = String(pkVal); + const pkMapVal = buildPkMap(pkColumns, contextMenu.row, pkIndexMaps); + const pkValStr = serializePkKey(pkMapVal); // Handle pending deletion revert const isPendingDelete = pendingDeletions?.[pkValStr] !== undefined; if (isPendingDelete && onRevertDeletion) { - onRevertDeletion(pkVal); + onRevertDeletion(pkMapVal); setContextMenu(null); return; } @@ -853,7 +858,7 @@ export const DataGrid = React.memo( if (rowPendingChanges && onPendingChange) { // Revert all pending changes for this row by setting them to undefined Object.keys(rowPendingChanges.changes).forEach((colName) => { - onPendingChange(pkVal, colName, undefined); + onPendingChange(pkMapVal, colName, undefined); }); setContextMenu(null); return; @@ -865,8 +870,8 @@ export const DataGrid = React.memo( onPendingChange, onRevertDeletion, onDiscardInsertion, - pkColumn, - pkIndexMap, + pkColumns, + pkIndexMaps, pendingChanges, pendingDeletions, ]); @@ -879,8 +884,8 @@ export const DataGrid = React.memo( if (mergedRow.type === "insertion" && mergedRow.tempId && onDiscardInsertion) { onDiscardInsertion(mergedRow.tempId); - } else if (mergedRow.type === "existing" && pkColumn && pkIndexMap !== null) { - pkVals.push(mergedRow.rowData[pkIndexMap]); + } else if (mergedRow.type === "existing" && pkColumns && pkIndexMaps.length > 0) { + pkVals.push(buildPkMap(pkColumns, mergedRow.rowData, pkIndexMaps)); } } @@ -892,7 +897,7 @@ export const DataGrid = React.memo( pkVals.forEach((v) => onMarkForDeletion(v)); } } - }, [mergedRows, onDiscardInsertion, onMarkForDeletion, onMarkMultipleForDeletion, pkColumn, pkIndexMap]); + }, [mergedRows, onDiscardInsertion, onMarkForDeletion, onMarkMultipleForDeletion, pkColumns, pkIndexMaps]); const deleteSelectedRow = useCallback(() => { if (!contextMenu) return; @@ -967,13 +972,13 @@ export const DataGrid = React.memo( if (isInsertion && onPendingInsertionChange && mergedRow.tempId) { onPendingInsertionChange(mergedRow.tempId, colName, value); - } else if (onPendingChange && pkIndexMap !== null) { - const pkVal = contextMenu.row[pkIndexMap]; - onPendingChange(pkVal, colName, value); + } else if (onPendingChange && pkIndexMaps.length > 0) { + const pkMapVal = buildPkMap(pkColumns!, contextMenu.row, pkIndexMaps); + onPendingChange(pkMapVal, colName, value); } setContextMenu(null); }, - [contextMenu, onPendingInsertionChange, onPendingChange, pkIndexMap], + [contextMenu, onPendingInsertionChange, onPendingChange, pkIndexMaps, pkColumns], ); const setCellGenerate = useCallback( @@ -1003,8 +1008,9 @@ export const DataGrid = React.memo( const formatted = formatDateTime(parseDateTime(raw), dateMode); if (isInsertion && onPendingInsertionChange && mergedRow?.tempId) { onPendingInsertionChange(mergedRow.tempId, colName, formatted); - } else if (onPendingChange && pkIndexMap !== null) { - onPendingChange(row[pkIndexMap], colName, formatted); + } else if (onPendingChange && pkIndexMaps.length > 0) { + const pkMapVal = buildPkMap(pkColumns!, row, pkIndexMaps); + onPendingChange(pkMapVal, colName, formatted); } }) .catch((err) => { @@ -1016,7 +1022,8 @@ export const DataGrid = React.memo( columnTypeMap, onPendingInsertionChange, onPendingChange, - pkIndexMap, + pkIndexMaps, + pkColumns, t, showAlert, ]); @@ -1141,7 +1148,7 @@ export const DataGrid = React.memo( autoIncrementColumns, defaultValueColumns, nullableColumns, - pkColumn, + pkColumns, pendingChanges, columnTypeMap, columnLengthMap, @@ -1149,7 +1156,7 @@ export const DataGrid = React.memo( fksByColumn, t, mergedRows, - pkIndexMap, + pkIndexMaps, parentViewportWidth, readonly: readonlyProp, updateSelection, @@ -1177,7 +1184,7 @@ export const DataGrid = React.memo( autoIncrementColumns, defaultValueColumns, nullableColumns, - pkColumn, + pkColumns, pendingChanges, columnTypeMap, columnLengthMap, @@ -1185,7 +1192,7 @@ export const DataGrid = React.memo( fksByColumn, t, mergedRows, - pkIndexMap, + pkIndexMaps, parentViewportWidth, readonlyProp, updateSelection, @@ -1283,8 +1290,8 @@ export const DataGrid = React.memo( const mergedRow = mergedRows[rowIndex]; const isInsertion = mergedRow?.type === "insertion"; const pkVal = - pkIndexMap !== null - ? String(rowOriginal[pkIndexMap]) + pkIndexMaps.length > 0 && pkColumns + ? serializePkKey(buildPkMap(pkColumns, rowOriginal as unknown[], pkIndexMaps)) : null; const isPendingDelete = !isInsertion && pkVal @@ -1320,8 +1327,8 @@ export const DataGrid = React.memo( // Check if this row has any pending changes, deletions, or is an insertion const isInsertion = contextMenu.mergedRow?.type === "insertion"; const pkVal = - pkIndexMap !== null - ? String(contextMenu.row[pkIndexMap]) + pkIndexMaps.length > 0 && pkColumns + ? serializePkKey(buildPkMap(pkColumns, contextMenu.row, pkIndexMaps)) : null; const hasPendingChanges = !isInsertion && pkVal && pendingChanges?.[pkVal] !== undefined; @@ -1578,7 +1585,7 @@ export const DataGrid = React.memo( focusField={sidebarRowData.focusField} connectionId={connectionId} tableName={tableName} - pkColumn={pkColumn} + pkColumns={pkColumns} schema={activeSchema} onChange={(colName, value) => { // Get the merged row to determine if it's an insertion or existing row @@ -1602,14 +1609,14 @@ export const DataGrid = React.memo( } else if ( !isInsertion && onPendingChange && - pkColumn && - pkIndexMap !== null + pkColumns && + pkIndexMaps.length > 0 ) { // Handle existing row updates const rowData = mergedRow.rowData; if (rowData) { - const pkVal = rowData[pkIndexMap]; - onPendingChange(pkVal, colName, value); + const pkMapVal = buildPkMap(pkColumns, rowData, pkIndexMaps); + onPendingChange(pkMapVal, colName, value); } } }} diff --git a/src/components/ui/DataGridRow.tsx b/src/components/ui/DataGridRow.tsx index 61318a27..002a56d8 100644 --- a/src/components/ui/DataGridRow.tsx +++ b/src/components/ui/DataGridRow.tsx @@ -5,6 +5,8 @@ import { resolveInsertionCellDisplay, resolveExistingCellDisplay, getCellStateClass, + buildPkMap, + serializePkKey, type ColumnDisplayInfo, type MergedRow, } from "../../utils/dataGrid"; @@ -32,7 +34,7 @@ export interface RowCtx { autoIncrementColumns?: string[]; defaultValueColumns?: string[]; nullableColumns?: string[]; - pkColumn?: string | null; + pkColumns?: string[] | null; pendingChanges?: Record< string, { pkOriginalValue: unknown; changes: Record } @@ -43,7 +45,7 @@ export interface RowCtx { fksByColumn: Map; t: (key: string, opts?: Record) => string; mergedRows: MergedRow[]; - pkIndexMap: number | null; + pkIndexMaps: number[]; parentViewportWidth: number; readonly: boolean | undefined; updateSelection: (s: Set) => void; @@ -150,7 +152,7 @@ export const MemoRow = React.memo(function MemoRow(rowCtx: MemoRowProps) { autoIncrementColumns, defaultValueColumns, nullableColumns, - pkColumn, + pkColumns, pendingChanges, columnTypeMap, columnLengthMap, @@ -158,7 +160,7 @@ export const MemoRow = React.memo(function MemoRow(rowCtx: MemoRowProps) { fksByColumn, t, mergedRows, - pkIndexMap, + pkIndexMaps, parentViewportWidth, readonly: readonlyProp, updateSelection, @@ -254,7 +256,7 @@ export const MemoRow = React.memo(function MemoRow(rowCtx: MemoRowProps) { : resolveExistingCellDisplay( cellValue, pkVal, - pkColumn, + pkColumns, pendingChanges, columnInfo, ); @@ -619,12 +621,11 @@ export const MemoRow = React.memo(function MemoRow(rowCtx: MemoRowProps) { const mergedRow = mergedRows[rowIndex]; const pendingExpansionValue = (() => { if (!mergedRow) return undefined; - if (mergedRow.type === "existing" && pkIndexMap !== null) { - const pkVal = mergedRow.rowData[pkIndexMap]; + if (mergedRow.type === "existing" && pkIndexMaps.length > 0 && pkColumns) { + const pkMapVal = buildPkMap(pkColumns, mergedRow.rowData, pkIndexMaps); + const pkKeyStr = serializePkKey(pkMapVal); const pendingVal = - pkVal !== null && pkVal !== undefined && pkVal !== "" - ? pendingChanges?.[String(pkVal)]?.changes?.[expColName] - : undefined; + pendingChanges?.[pkKeyStr]?.changes?.[expColName]; if (pendingVal !== undefined) return pendingVal; } return mergedRow.rowData?.[expandedCell.colIndex]; @@ -644,10 +645,11 @@ export const MemoRow = React.memo(function MemoRow(rowCtx: MemoRowProps) { } else if ( mergedRow.type === "existing" && onPendingChange && - pkIndexMap !== null + pkIndexMaps.length > 0 && + pkColumns ) { - const pkVal = mergedRow.rowData[pkIndexMap]; - onPendingChange(pkVal, expColName, next); + const pkMapVal = buildPkMap(pkColumns, mergedRow.rowData, pkIndexMaps); + onPendingChange(pkMapVal, expColName, next); } setExpandedCell(null); }; diff --git a/src/components/ui/FieldEditor.tsx b/src/components/ui/FieldEditor.tsx index 28621167..44f234b5 100644 --- a/src/components/ui/FieldEditor.tsx +++ b/src/components/ui/FieldEditor.tsx @@ -29,8 +29,7 @@ export interface FieldEditorProps { detectJsonInTextColumns?: boolean; connectionId?: string | null; tableName?: string | null; - pkCol?: string | null; - pkVal?: unknown; + pkMap?: Record | null; schema?: string | null; } @@ -54,8 +53,7 @@ export const FieldEditor = ({ detectJsonInTextColumns = false, connectionId, tableName, - pkCol, - pkVal, + pkMap, schema, }: FieldEditorProps) => { const { t } = useTranslation(); @@ -102,8 +100,7 @@ export const FieldEditor = ({ placeholder={defaultPlaceholder} connectionId={connectionId} tableName={tableName} - pkCol={pkCol} - pkVal={pkVal} + pkMap={pkMap} colName={name} schema={schema} /> diff --git a/src/components/ui/ResultEntryContent.tsx b/src/components/ui/ResultEntryContent.tsx index 7338fe5a..cfae6834 100644 --- a/src/components/ui/ResultEntryContent.tsx +++ b/src/components/ui/ResultEntryContent.tsx @@ -68,7 +68,7 @@ export function ResultEntryContent({ columns={entry.result.columns} data={entry.result.rows} tableName={null} - pkColumn={null} + pkColumns={null} connectionId={connectionId} selectedRows={new Set()} onSelectionChange={() => {}} @@ -115,7 +115,7 @@ export function ResultEntryContent({ columns={entry.result.columns} data={entry.result.rows} tableName={null} - pkColumn={null} + pkColumns={null} connectionId={connectionId} selectedRows={new Set()} onSelectionChange={() => {}} diff --git a/src/components/ui/RowEditorSidebar.tsx b/src/components/ui/RowEditorSidebar.tsx index a9d06d92..8ec66a04 100644 --- a/src/components/ui/RowEditorSidebar.tsx +++ b/src/components/ui/RowEditorSidebar.tsx @@ -1,4 +1,4 @@ -import { useRef, useEffect } from "react"; +import { useRef, useEffect, useMemo } from "react"; import { useTranslation } from "react-i18next"; import { X } from "lucide-react"; import { FieldEditor } from "./FieldEditor"; @@ -22,7 +22,7 @@ interface RowEditorSidebarProps { detectJsonInTextColumns?: boolean; connectionId?: string | null; tableName?: string | null; - pkColumn?: string | null; + pkColumns?: string[] | null; schema?: string | null; } @@ -42,7 +42,7 @@ export const RowEditorSidebar = ({ detectJsonInTextColumns = false, connectionId, tableName, - pkColumn, + pkColumns, schema, }: RowEditorSidebarProps) => { const { t } = useTranslation(); @@ -52,6 +52,14 @@ export const RowEditorSidebar = ({ onChange: (fieldName, value) => onChange(fieldName, value), }); + const pkMap = useMemo( + () => + pkColumns && pkColumns.length > 0 + ? Object.fromEntries(pkColumns.map((col) => [col, rowData[col]])) + : undefined, + [pkColumns, rowData], + ); + // Refs to track field containers for scrolling const fieldRefs = useRef>({}); @@ -169,8 +177,7 @@ export const RowEditorSidebar = ({ isNullable={nullableColumns?.includes(column.name)} connectionId={connectionId} tableName={tableName} - pkCol={pkColumn} - pkVal={pkColumn ? rowData[pkColumn] : undefined} + pkMap={pkMap} schema={schema} /> { pendingInsertions, selectedRows, result, - pkColumn, + pkColumns, } = activeTab; const hasGlobalPending = (pendingChanges && Object.keys(pendingChanges).length > 0) || @@ -489,16 +490,16 @@ export const Editor = () => { } // This is an existing row - check for changes/deletions - if (!result || !pkColumn) return false; - const pkIndex = result.columns.indexOf(pkColumn); - if (pkIndex === -1) return false; + if (!result || !pkColumns || pkColumns.length === 0) return false; + const pkIndices = pkColumns.map((c) => result.columns.indexOf(c)); + if (pkIndices.some((i) => i === -1)) return false; const row = result.rows[rowIndex]; if (!row) return false; - const pkVal = String(row[pkIndex]); + const pkKey = serializePkKey(buildPkMap(pkColumns, row, pkIndices)); return ( - (pendingChanges && pendingChanges[pkVal]) || - (pendingDeletions && pendingDeletions[pkVal]) + (pendingChanges && pendingChanges[pkKey]) || + (pendingDeletions && pendingDeletions[pkKey]) ); }); }, [activeTab]); @@ -547,7 +548,7 @@ export const Editor = () => { return [] as ForeignKey[]; }), ]); - const pk = cols.find((c) => c.is_pk); + const pks = cols.filter((c) => c.is_pk).map((c) => c.name); const autoInc = cols .filter((c) => c.is_auto_increment) .map((c) => c.name); @@ -560,7 +561,7 @@ export const Editor = () => { const targetId = tabId || activeTabId; if (targetId) updateTab(targetId, { - pkColumn: pk ? pk.name : null, + pkColumns: pks.length > 0 ? pks : null, autoIncrementColumns: autoInc, defaultValueColumns: defaultVal, nullableColumns: nullable, @@ -569,11 +570,11 @@ export const Editor = () => { }); } catch (e) { console.error("Failed to fetch PK:", e); - // Even if PK fetch fails, set pkColumn to null to unblock the UI + // Even if PK fetch fails, set pkColumns to null to unblock the UI const targetId = tabId || activeTabId; if (targetId) updateTab(targetId, { - pkColumn: null, + pkColumns: null, autoIncrementColumns: [], defaultValueColumns: [], nullableColumns: [], @@ -753,7 +754,7 @@ export const Editor = () => { // Fetch column metadata in the background; tab updates when ready fetchPkColumn(tableName, targetTabId, targetTab?.schema ?? undefined); } else { - updateTab(targetTabId, { pkColumn: null }); + updateTab(targetTabId, { pkColumns: null }); } if (shouldRecordHistory) { @@ -1410,7 +1411,7 @@ export const Editor = () => { const currentTab = tabsRef.current.find((t) => t.id === tabId); if (!currentTab) return; - const pkKey = String(pkVal); + const pkKey = serializePkKey(pkVal as Record); const currentPending = currentTab.pendingChanges || {}; const rowEntry = currentPending[pkKey] || { pkOriginalValue: pkVal, @@ -1474,13 +1475,15 @@ export const Editor = () => { activeTab.selectedRows.forEach((rowIndex) => { if (rowIndex < existingRowCount) { // Existing row - add to pending deletions - if (activeTab.result && activeTab.pkColumn) { - const pkIndex = activeTab.result.columns.indexOf(activeTab.pkColumn); - if (pkIndex !== -1) { + if (activeTab.result && activeTab.pkColumns && activeTab.pkColumns.length > 0) { + const pkCols = activeTab.pkColumns; + const pkIndices = pkCols.map((c) => activeTab.result!.columns.indexOf(c)); + if (pkIndices.every((i) => i !== -1)) { const row = activeTab.result.rows[rowIndex]; if (row) { - const pkVal = row[pkIndex]; - newPendingDeletions[String(pkVal)] = pkVal; + const pkMapVal = buildPkMap(pkCols, row, pkIndices); + const pkKey = serializePkKey(pkMapVal); + newPendingDeletions[pkKey] = pkMapVal; } } } @@ -1558,7 +1561,7 @@ export const Editor = () => { const currentTab = tabsRef.current.find((t) => t.id === tabId); if (!currentTab?.pendingDeletions) return; - const pkKey = String(pkVal); + const pkKey = serializePkKey(pkVal as Record); const newPendingDeletions = { ...currentTab.pendingDeletions }; delete newPendingDeletions[pkKey]; @@ -1579,7 +1582,7 @@ export const Editor = () => { const currentTab = tabsRef.current.find((t) => t.id === tabId); if (!currentTab) return; - const pkKey = String(pkVal); + const pkKey = serializePkKey(pkVal as Record); const currentPendingDeletions = currentTab.pendingDeletions || {}; const newPendingDeletions = { ...currentPendingDeletions, @@ -1600,7 +1603,7 @@ export const Editor = () => { const newPendingDeletions = { ...(currentTab.pendingDeletions || {}) }; for (const pkVal of pkVals) { - newPendingDeletions[String(pkVal)] = pkVal; + newPendingDeletions[serializePkKey(pkVal as Record)] = pkVal; } updateTab(tabId, { pendingDeletions: newPendingDeletions }); @@ -1712,11 +1715,11 @@ export const Editor = () => { }; } - // Ensure pkColumn and autoIncrementColumns are set - if (!activeTab.pkColumn) { - const pk = columns.find((c) => c.is_pk); - if (pk) { - updates.pkColumn = pk.name; + // Ensure pkColumns and autoIncrementColumns are set + if (!activeTab.pkColumns || activeTab.pkColumns.length === 0) { + const pks = columns.filter((c) => c.is_pk).map((c) => c.name); + if (pks.length > 0) { + updates.pkColumns = pks; } } @@ -1768,52 +1771,52 @@ export const Editor = () => { const handleSubmitChanges = useCallback(async () => { if (!activeTab || !activeTab.activeTable || !activeConnectionId) return; - // pkColumn is required for updates/deletions but not for insertions-only - const hasPkColumn = !!activeTab.pkColumn; + // pkColumns is required for updates/deletions but not for insertions-only + const hasPkColumns = !!(activeTab.pkColumns && activeTab.pkColumns.length > 0); const { pendingChanges, pendingDeletions, pendingInsertions, activeTable, - pkColumn, + pkColumns, selectedRows, } = activeTab; - const updates: { pkVal: unknown; colName: string; newVal: unknown }[] = []; - const deletions: unknown[] = []; + const updates: { pkVal: Record; colName: string; newVal: unknown }[] = []; + const deletions: Record[] = []; const insertions: { tempId: string; data: Record }[] = []; // Filter pending changes by selected rows IF there is a selection AND applyToAll is false const hasSelection = !applyToAll && selectedRows && selectedRows.length > 0; const selectedPkSet = new Set(); - if (hasSelection && activeTab.result && hasPkColumn && pkColumn) { - const pkIndex = activeTab.result.columns.indexOf(pkColumn); - if (pkIndex !== -1) { + if (hasSelection && activeTab.result && hasPkColumns && pkColumns) { + const pkIndices = pkColumns.map((c) => activeTab.result!.columns.indexOf(c)); + if (pkIndices.every((i) => i !== -1)) { selectedRows.forEach((rowIndex) => { const row = activeTab.result!.rows[rowIndex]; - if (row) selectedPkSet.add(String(row[pkIndex])); + if (row) selectedPkSet.add(serializePkKey(buildPkMap(pkColumns, row, pkIndices))); }); } } - if (hasPkColumn && pkColumn && pendingChanges) { + if (hasPkColumns && pkColumns && pendingChanges) { for (const [pkKey, rowData] of Object.entries(pendingChanges)) { // Apply filter if selection exists (and applyToAll is false) if (hasSelection && !selectedPkSet.has(pkKey)) continue; const { pkOriginalValue, changes } = rowData; for (const [colName, newVal] of Object.entries(changes)) { - updates.push({ pkVal: pkOriginalValue, colName, newVal }); + updates.push({ pkVal: pkOriginalValue as Record, colName, newVal }); } } } - if (hasPkColumn && pkColumn && pendingDeletions) { + if (hasPkColumns && pkColumns && pendingDeletions) { for (const [pkKey, pkVal] of Object.entries(pendingDeletions)) { // Apply filter if selection exists (and applyToAll is false) if (hasSelection && !selectedPkSet.has(pkKey)) continue; - deletions.push(pkVal); + deletions.push(pkVal as Record); } } @@ -1897,12 +1900,11 @@ export const Editor = () => { // Deletions if (deletions.length > 0) { promises.push( - ...deletions.map((pkVal) => + ...deletions.map((pkMap) => invoke("delete_record", { connectionId: activeConnectionId, table: activeTable, - pkCol: pkColumn, - pkVal, + pkMap, ...(activeSchema ? { schema: activeSchema } : {}), ...databaseParam, }), @@ -1917,8 +1919,7 @@ export const Editor = () => { invoke("update_record", { connectionId: activeConnectionId, table: activeTable, - pkCol: pkColumn, - pkVal: u.pkVal, + pkMap: u.pkVal, colName: u.colName, newVal: u.newVal, ...(activeSchema ? { schema: activeSchema } : {}), @@ -1951,8 +1952,8 @@ export const Editor = () => { const newPendingInsertions = { ...(pendingInsertions || {}) }; // Partial cleanup - remove only processed changes - updates.forEach((u) => delete newPendingChanges[String(u.pkVal)]); - deletions.forEach((d) => delete newPendingDeletions[String(d)]); + updates.forEach((u) => delete newPendingChanges[serializePkKey(u.pkVal)]); + deletions.forEach((d) => delete newPendingDeletions[serializePkKey(d as Record)]); insertions.forEach((i) => delete newPendingInsertions[i.tempId]); // Cleanup empty change objects @@ -2057,7 +2058,7 @@ export const Editor = () => { const { selectedRows, result, - pkColumn, + pkColumns, pendingChanges, pendingDeletions, pendingInsertions, @@ -2083,12 +2084,12 @@ export const Editor = () => { }); // For existing rows, also collect their PK values - if (result && pkColumn) { - const pkIndex = result.columns.indexOf(pkColumn); - if (pkIndex !== -1) { + if (result && pkColumns && pkColumns.length > 0) { + const pkIndices = pkColumns.map((c) => result.columns.indexOf(c)); + if (pkIndices.every((i) => i !== -1)) { selectedRows.forEach((rowIndex) => { const row = result.rows[rowIndex]; - if (row) selectedPkSet.add(String(row[pkIndex])); + if (row) selectedPkSet.add(serializePkKey(buildPkMap(pkColumns, row, pkIndices))); }); } } @@ -3339,7 +3340,7 @@ export const Editor = () => { columns={activeTab.result?.columns || []} data={activeTab.result?.rows || []} tableName={activeTab.activeTable} - pkColumn={activeTab.pkColumn} + pkColumns={activeTab.pkColumns} autoIncrementColumns={activeTab.autoIncrementColumns} defaultValueColumns={activeTab.defaultValueColumns} nullableColumns={activeTab.nullableColumns} diff --git a/src/types/editor.ts b/src/types/editor.ts index 4e4df966..79fb2190 100644 --- a/src/types/editor.ts +++ b/src/types/editor.ts @@ -61,7 +61,7 @@ export interface QueryResultEntry { isLoading: boolean; page: number; activeTable: string | null; - pkColumn: string | null; + pkColumns: string[] | null; } import type { NotebookState } from "./notebook"; @@ -88,7 +88,7 @@ export interface Tab { executionTime: number | null; page: number; activeTable: string | null; - pkColumn: string | null; + pkColumns: string[] | null; autoIncrementColumns?: string[]; // Names of auto-increment columns defaultValueColumns?: string[]; // Names of columns with default values nullableColumns?: string[]; // Names of nullable columns diff --git a/src/utils/dataGrid.ts b/src/utils/dataGrid.ts index f67e6026..fffa64e9 100644 --- a/src/utils/dataGrid.ts +++ b/src/utils/dataGrid.ts @@ -10,6 +10,31 @@ import { isJsonColumn } from "./json"; /** Sentinel value indicating that the database DEFAULT value should be used */ export const USE_DEFAULT_SENTINEL = "__USE_DEFAULT__"; +/** Build an object mapping PK column names to their values from a data row. */ +export function buildPkMap( + pkColumns: string[], + row: unknown[], + pkIndices: number[], +): Record { + const map: Record = {}; + for (let i = 0; i < pkColumns.length; i++) { + map[pkColumns[i]] = row[pkIndices[i]]; + } + return map; +} + +/** + * Produce a stable string key for a pk map to use as a pendingChanges key. + * Keys are sorted alphabetically before serializing so the result is + * deterministic regardless of insertion order. + */ +export function serializePkKey(pkMap: Record): string { + const entries = Object.entries(pkMap).sort(([a], [b]) => + a.localeCompare(b), + ); + return JSON.stringify(Object.fromEntries(entries)); +} + function cellValuesEqual(a: unknown, b: unknown): boolean { if (a === b) return true; if (typeof a === "object" || typeof b === "object") { @@ -218,7 +243,7 @@ export function resolveInsertionCellDisplay( export function resolveExistingCellDisplay( cellValue: unknown, pkVal: string | null, - pkColumn: string | null | undefined, + pkColumns: string[] | null | undefined, pendingChanges: | Record< string, @@ -227,9 +252,10 @@ export function resolveExistingCellDisplay( | undefined, columnInfo: ColumnDisplayInfo, ): ResolvedCellDisplay { + const hasPk = pkColumns && pkColumns.length > 0; const pendingVal = - pkColumn && pkVal && pendingChanges?.[pkVal]?.changes?.[columnInfo.colName]; - const hasPendingChange = pkColumn && pkVal ? pendingVal !== undefined : false; + hasPk && pkVal && pendingChanges?.[pkVal]?.changes?.[columnInfo.colName]; + const hasPendingChange = hasPk && pkVal ? pendingVal !== undefined : false; let displayValue = hasPendingChange ? pendingVal : cellValue; const isModified = hasPendingChange && !cellValuesEqual(pendingVal, cellValue); diff --git a/src/utils/editor.ts b/src/utils/editor.ts index d3d5d814..9b6982ea 100644 --- a/src/utils/editor.ts +++ b/src/utils/editor.ts @@ -90,7 +90,7 @@ export function createInitialTabState( executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, isLoading: false, connectionId: connectionId || "", isEditorOpen: partial?.isEditorOpen ?? partial?.type !== "table", diff --git a/src/utils/multiResult.ts b/src/utils/multiResult.ts index ea6040c9..ad9d1041 100644 --- a/src/utils/multiResult.ts +++ b/src/utils/multiResult.ts @@ -18,7 +18,7 @@ export function createResultEntries( isLoading: true, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, })); } diff --git a/src/utils/tabCleaner.ts b/src/utils/tabCleaner.ts index 5df0d551..0a79c937 100644 --- a/src/utils/tabCleaner.ts +++ b/src/utils/tabCleaner.ts @@ -10,7 +10,7 @@ export interface CleanedTab { query: string; page: number; activeTable: string | null; - pkColumn: string | null; + pkColumns: string[] | null; connectionId: string; flowState?: Tab['flowState']; isEditorOpen?: boolean; @@ -39,7 +39,7 @@ export function cleanTabForStorage(tab: Tab): CleanedTab { query: tab.query, page: tab.page, activeTable: tab.activeTable, - pkColumn: tab.pkColumn, + pkColumns: tab.pkColumns, connectionId: tab.connectionId, flowState: tab.flowState, isEditorOpen: tab.isEditorOpen, @@ -68,7 +68,7 @@ export function restoreTabFromStorage(cleanedTab: Partial): Tab { query: cleanedTab.query || '', page: cleanedTab.page || 1, activeTable: cleanedTab.activeTable || null, - pkColumn: cleanedTab.pkColumn || null, + pkColumns: cleanedTab.pkColumns || null, connectionId: cleanedTab.connectionId || '', result: null, error: '', diff --git a/tests/components/ui/MultiResultPanel.test.tsx b/tests/components/ui/MultiResultPanel.test.tsx index 1fcb98d1..18f849c1 100644 --- a/tests/components/ui/MultiResultPanel.test.tsx +++ b/tests/components/ui/MultiResultPanel.test.tsx @@ -57,7 +57,7 @@ function makeEntry( isLoading: true, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, ...overrides, }; } diff --git a/tests/hooks/useEditor.test.ts b/tests/hooks/useEditor.test.ts index 1ba99ce3..30e91a27 100644 --- a/tests/hooks/useEditor.test.ts +++ b/tests/hooks/useEditor.test.ts @@ -27,7 +27,7 @@ describe('useEditor', () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-123', }; diff --git a/tests/utils/dataGrid.test.ts b/tests/utils/dataGrid.test.ts index 0a93f620..d3811e58 100644 --- a/tests/utils/dataGrid.test.ts +++ b/tests/utils/dataGrid.test.ts @@ -8,6 +8,8 @@ import { resolveInsertionCellDisplay, resolveExistingCellDisplay, getCellStateClass, + buildPkMap, + serializePkKey, type ColumnDisplayInfo, type CellClassParams, } from '../../src/utils/dataGrid'; @@ -640,4 +642,48 @@ describe('dataGrid utils', () => { expect(result).toContain('font-medium'); }); }); + + describe('buildPkMap', () => { + it('maps a single PK column to its value', () => { + expect(buildPkMap(['id'], [10, 'Alice'], [0])).toEqual({ id: 10 }); + }); + + it('maps composite PK columns to their values', () => { + expect(buildPkMap(['org_id', 'user_id'], [1, 2, 'extra'], [0, 1])).toEqual({ + org_id: 1, + user_id: 2, + }); + }); + + it('uses pkIndices to pick non-zero positions from the row', () => { + expect(buildPkMap(['id'], ['name', 'email', 99], [2])).toEqual({ id: 99 }); + }); + + it('handles null and string values', () => { + expect(buildPkMap(['a', 'b'], [null, 'hello'], [0, 1])).toEqual({ a: null, b: 'hello' }); + }); + }); + + describe('serializePkKey', () => { + it('serializes a single-key map as JSON', () => { + expect(serializePkKey({ id: 42 })).toBe('{"id":42}'); + }); + + it('sorts composite keys alphabetically regardless of insertion order', () => { + const pkMap: Record = { z_col: 1, a_col: 2 }; + expect(serializePkKey(pkMap)).toBe('{"a_col":2,"z_col":1}'); + }); + + it('produces the same key regardless of insertion order', () => { + expect(serializePkKey({ a: 1, b: 2 })).toBe(serializePkKey({ b: 2, a: 1 })); + }); + + it('handles null pk values', () => { + expect(serializePkKey({ id: null })).toBe('{"id":null}'); + }); + + it('handles string pk values', () => { + expect(serializePkKey({ slug: 'hello-world' })).toBe('{"slug":"hello-world"}'); + }); + }); }); diff --git a/tests/utils/editor.test.ts b/tests/utils/editor.test.ts index 751eed48..06da3bef 100644 --- a/tests/utils/editor.test.ts +++ b/tests/utils/editor.test.ts @@ -54,7 +54,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, isLoading: false, connectionId: "conn-1", isEditorOpen: true, @@ -106,7 +106,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -176,7 +176,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -251,7 +251,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -297,7 +297,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -359,7 +359,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -463,7 +463,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -503,7 +503,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -534,7 +534,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -601,7 +601,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -655,7 +655,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); @@ -774,7 +774,7 @@ describe("editor", () => { executionTime: null, page: 1, activeTable: "users", - pkColumn: null, + pkColumns: null, connectionId: "conn-1", ...overrides, }); diff --git a/tests/utils/multiResult.test.ts b/tests/utils/multiResult.test.ts index da964ec9..99430316 100644 --- a/tests/utils/multiResult.test.ts +++ b/tests/utils/multiResult.test.ts @@ -26,7 +26,7 @@ function makeEntry(overrides: Partial = {}): QueryResultEntry isLoading: true, page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, ...overrides, }; } @@ -65,7 +65,7 @@ describe("multiResult", () => { expect(entry.executionTime).toBeNull(); expect(entry.page).toBe(1); expect(entry.activeTable).toBeNull(); - expect(entry.pkColumn).toBeNull(); + expect(entry.pkColumns).toBeNull(); } }); diff --git a/tests/utils/tabCleaner.test.ts b/tests/utils/tabCleaner.test.ts index 3db9d3ac..bd737bd5 100644 --- a/tests/utils/tabCleaner.test.ts +++ b/tests/utils/tabCleaner.test.ts @@ -12,7 +12,7 @@ describe('tabCleaner', () => { query: 'SELECT * FROM users', page: 2, activeTable: 'users', - pkColumn: 'id', + pkColumns: ['id'], connectionId: 'conn-456', isEditorOpen: true, filterClause: 'age > 18', @@ -38,7 +38,7 @@ describe('tabCleaner', () => { expect(cleaned.query).toBe('SELECT * FROM users'); expect(cleaned.page).toBe(2); expect(cleaned.activeTable).toBe('users'); - expect(cleaned.pkColumn).toBe('id'); + expect(cleaned.pkColumns).toEqual(['id']); expect(cleaned.connectionId).toBe('conn-456'); expect(cleaned.isEditorOpen).toBe(true); expect(cleaned.filterClause).toBe('age > 18'); @@ -64,7 +64,7 @@ describe('tabCleaner', () => { query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-456', result: null, error: '', @@ -74,7 +74,7 @@ describe('tabCleaner', () => { const cleaned = cleanTabForStorage(tab); expect(cleaned.activeTable).toBeNull(); - expect(cleaned.pkColumn).toBeNull(); + expect(cleaned.pkColumns).toBeNull(); expect(cleaned.query).toBe(''); }); @@ -91,7 +91,7 @@ describe('tabCleaner', () => { query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-456', flowState, result: null, @@ -112,7 +112,7 @@ describe('tabCleaner', () => { query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-456', result: null, error: '', @@ -149,7 +149,7 @@ describe('tabCleaner', () => { query: 'SELECT * FROM users', page: 3, activeTable: 'users', - pkColumn: 'id', + pkColumns: ['id'], connectionId: 'conn-456', isEditorOpen: false, filterClause: 'status = "active"', @@ -180,7 +180,7 @@ describe('tabCleaner', () => { query: 'SELECT * FROM users', page: 2, activeTable: 'users', - pkColumn: 'id', + pkColumns: ['id'], connectionId: 'conn-456', isEditorOpen: true, filterClause: 'age > 18', @@ -195,7 +195,7 @@ describe('tabCleaner', () => { expect(restored.query).toBe('SELECT * FROM users'); expect(restored.page).toBe(2); expect(restored.activeTable).toBe('users'); - expect(restored.pkColumn).toBe('id'); + expect(restored.pkColumns).toEqual(['id']); expect(restored.connectionId).toBe('conn-456'); expect(restored.isEditorOpen).toBe(true); expect(restored.filterClause).toBe('age > 18'); @@ -223,7 +223,7 @@ describe('tabCleaner', () => { expect(restored.query).toBe(''); expect(restored.page).toBe(1); expect(restored.activeTable).toBeNull(); - expect(restored.pkColumn).toBeNull(); + expect(restored.pkColumns).toBeNull(); expect(restored.connectionId).toBe(''); }); @@ -236,7 +236,7 @@ describe('tabCleaner', () => { expect(restored.query).toBe(''); expect(restored.page).toBe(1); expect(restored.activeTable).toBeNull(); - expect(restored.pkColumn).toBeNull(); + expect(restored.pkColumns).toBeNull(); expect(restored.connectionId).toBe(''); expect(restored.result).toBeNull(); expect(restored.error).toBe(''); @@ -252,7 +252,7 @@ describe('tabCleaner', () => { query: 'SELECT * FROM orders', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-789', flowState: { nodes: [{ id: '1', type: 'table', position: { x: 0, y: 0 }, data: {} }], @@ -280,7 +280,7 @@ describe('tabCleaner', () => { query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-456', notebookId: 'nb_abc123', }; @@ -300,7 +300,7 @@ describe('tabCleaner', () => { query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId: 'conn-rt', result: null, error: '', @@ -323,7 +323,7 @@ describe('tabCleaner', () => { query: 'SELECT * FROM products WHERE price > 100', page: 5, activeTable: 'products', - pkColumn: 'product_id', + pkColumns: ['product_id'], connectionId: 'conn-999', isEditorOpen: true, filterClause: 'price > 100', @@ -347,7 +347,7 @@ describe('tabCleaner', () => { expect(restored.query).toBe(originalTab.query); expect(restored.page).toBe(originalTab.page); expect(restored.activeTable).toBe(originalTab.activeTable); - expect(restored.pkColumn).toBe(originalTab.pkColumn); + expect(restored.pkColumns).toEqual(originalTab.pkColumns); expect(restored.connectionId).toBe(originalTab.connectionId); expect(restored.isEditorOpen).toBe(originalTab.isEditorOpen); expect(restored.filterClause).toBe(originalTab.filterClause); diff --git a/tests/utils/tabFilters.test.ts b/tests/utils/tabFilters.test.ts index 00ef7162..288bb034 100644 --- a/tests/utils/tabFilters.test.ts +++ b/tests/utils/tabFilters.test.ts @@ -15,7 +15,7 @@ const createMockTab = (id: string, connectionId: string, title: string = 'Tab'): query: '', page: 1, activeTable: null, - pkColumn: null, + pkColumns: null, connectionId, result: null, error: '',