From 06494afb6cdde757b1d610220e3b7252daf4d15a Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 17 Mar 2026 10:44:20 +0000 Subject: [PATCH 01/20] WIP: ODBC --- Cargo.toml | 4 + ODBC.md | 312 +++++++++++++++++++++++++++++++ ggsql-jupyter/Cargo.toml | 4 + ggsql-jupyter/src/connection.rs | 209 +++++++++++++++++++++ ggsql-jupyter/src/display.rs | 15 ++ ggsql-jupyter/src/executor.rs | 222 +++++++++++++++++++--- ggsql-jupyter/src/kernel.rs | 186 ++++++++++++++++++- ggsql-jupyter/src/lib.rs | 1 + ggsql-jupyter/src/main.rs | 7 +- ggsql-python/src/lib.rs | 4 + ggsql-vscode/src/connections.ts | 151 +++++++++++++++ ggsql-vscode/src/extension.ts | 10 + ggsql-vscode/src/manager.ts | 13 +- src/Cargo.toml | 5 +- src/execute/mod.rs | 2 +- src/reader/connection.rs | 33 +++- src/reader/duckdb.rs | 4 + src/reader/mod.rs | 110 +++++++---- src/reader/odbc.rs | 319 ++++++++++++++++++++++++++++++++ src/reader/polars_sql.rs | 4 + src/reader/sqlite.rs | 27 +++ 21 files changed, 1577 insertions(+), 65 deletions(-) create mode 100644 ODBC.md create mode 100644 ggsql-jupyter/src/connection.rs create mode 100644 ggsql-vscode/src/connections.ts create mode 100644 src/reader/odbc.rs diff --git a/Cargo.toml b/Cargo.toml index 87cfa10e..17d0d099 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,10 @@ arrow = { version = "56", default-features = false, features = ["ipc"] } postgres = "0.19" rusqlite = { version = "0.38", features = ["bundled", "chrono", "functions", "window"] } +# ODBC +odbc-api = "13" +toml_edit = "0.22" + # Writers plotters = "0.3" diff --git a/ODBC.md b/ODBC.md new file mode 100644 index 00000000..51cdcd50 --- /dev/null +++ b/ODBC.md @@ -0,0 +1,312 @@ +# Positron Connections Pane Integration for ggsql + +## Context + +ggsql's Jupyter kernel (`ggsql-jupyter`) and VS Code extension (`ggsql-vscode`) currently have no integration with Positron's Connections pane. The kernel is hardcoded to `duckdb://memory` with no way to configure the database connection. 
This plan adds: + +1. **Connection comm protocol** in the kernel — so database schemas appear in the Connections pane +2. **Connection drivers** in the extension — so users can create connections via the "New Connection" dialog +3. **Generic ODBC reader** in the library — supporting Snowflake (with Workbench credentials), PostgreSQL, SQL Server, etc. +4. **Dynamic connection switching** in the kernel — via meta-commands executed from the connection dialog + +## Architecture Overview + +``` +New Connection Dialog ggsql-jupyter Kernel + (ggsql-vscode) (Rust) + │ │ + │ generateCode() → │ + │ "-- @connect: odbc://snowflake?…" │ + │ │ + │ connect() → │ + │ positron.runtime.executeCode()───┤ + │ │ detect meta-command + │ │ create OdbcReader + │ │ open positron.connection comm + │ │ + │ ◄── comm_open ──────┤ (kernel initiates) + │ │ + Connections Pane │ + │── list_objects([]) ─────────────►│ SELECT … FROM information_schema + │◄─ [{name:"public",kind:"schema"}]│ + │── list_fields([schema,table]) ──►│ SELECT … FROM information_schema.columns + │◄─ [{name:"id",dtype:"integer"}] │ +``` + +Key insight: The **kernel** opens the `positron.connection` comm (backend-initiated), unlike variables/ui comms which are frontend-initiated. + +--- + +## Part 1: ODBC Reader (`src/reader/odbc.rs`) + +### New file: `src/reader/odbc.rs` + +Generic ODBC reader using `odbc-api` crate. Implements `Reader` trait. 
+ +**Connection string format**: `odbc://` prefix + raw ODBC connection string (no URI parsing of the payload) +- `odbc://Driver=Snowflake;Server=myaccount.snowflakecomputing.com;Warehouse=WH` — Snowflake +- `odbc://Driver={PostgreSQL};Server=localhost;Database=mydb` — PostgreSQL +- The extension's driver dialogs build the ODBC string in `generateCode()` and prefix it with `odbc://` +- Parsing: strip `odbc://` prefix, pass remainder directly to `SQLDriverConnect` + +**Core implementation**: +- `OdbcReader::from_connection_string(uri)` — parse URI, detect credentials, connect +- `execute_sql(&self, sql)` — execute via `connection.execute()`, convert cursor → DataFrame +- Cursor → DataFrame conversion: iterate ODBC columnar buffers, map ODBC types to Polars types +- `register()` returns error (ODBC doesn't support temp table registration easily) +- `dialect()` returns dialect variant detected from DBMS info + +**Snowflake Workbench credential detection** (per `~/work/positron/CONNECTIONS.md`): +When `OdbcReader` sees `Driver=Snowflake` in the connection string and no `Token=` is present: +1. Read `SNOWFLAKE_HOME` env var +2. If path contains `"posit-workbench"`, parse `$SNOWFLAKE_HOME/connections.toml` +3. Extract `account` + `token` from `[workbench]` section +4. Inject `Authenticator=oauth;Token=` into the connection string before connecting +5. If no Workbench credentials found, connect as-is (user may have specified auth in the string) + +**Credential storage**: Trust Positron's secret storage — the full `-- @connect:` meta-command (including any credentials in the ODBC string) is stored in the `code` field of the connection comm metadata. Positron persists this in encrypted workspace secret storage for reconnection. + +**OdbcDialect**: Implements `SqlDialect` with a variant enum (Generic, Snowflake, PostgreSQL) detected from DBMS metadata at connection time. 
+ +### Modify: `src/reader/mod.rs` +- Add `#[cfg(feature = "odbc")] pub mod odbc;` and re-export `OdbcReader` +- Remove `where Self: Sized` bound from `fn execute()` + +### Modify: `src/reader/connection.rs` +- Add `ODBC(String)` variant to `ConnectionInfo` enum +- Parse `odbc://` prefix in `parse_connection_string()` + +### Modify: `src/execute/mod.rs` +- Change `prepare_data_with_reader(query: &str, reader: &R)` → `prepare_data_with_reader(query: &str, reader: &dyn Reader)` +- This is safe: all methods called on reader (`execute_sql`, `dialect`, `register`, `unregister`) are object-safe. `materialize_ctes` already takes `&dyn Reader`. + +### Modify: `src/Cargo.toml` +- Add feature: `odbc = ["dep:odbc-api", "dep:toml"]` +- Add dependencies: `odbc-api = { version = "21", optional = true }`, `toml = { version = "0.8", optional = true }` +- Add `"odbc"` to `all-readers` feature list + +--- + +## Part 2: Kernel Connection Comm Protocol (`ggsql-jupyter/`) + +### New file: `ggsql-jupyter/src/connection.rs` + +Module for database schema introspection via the reader. All methods query `information_schema` using `reader.execute_sql()`. 
+**Methods**:
+- `list_objects(reader, path) -> Vec<ObjectSchema>`:
+  - Depth depends on dialect — `SqlDialect::has_catalogs()` (true for Snowflake, false for DuckDB/Postgres)
+  - **Without catalogs** (DuckDB, PostgreSQL):
+    - `[]` → query `information_schema.schemata` → return schemas
+    - `[schema]` → query `information_schema.tables WHERE table_schema = '<schema>'` → return tables/views
+  - **With catalogs** (Snowflake):
+    - `[]` → query `SHOW DATABASES` or `information_schema.schemata` grouped by catalog → return catalogs with `kind = "catalog"`
+    - `[catalog]` → query `information_schema.schemata WHERE catalog_name = '<catalog>'` → return schemas
+    - `[catalog, schema]` → query `information_schema.tables WHERE table_catalog = '<catalog>' AND table_schema = '<schema>'` → return tables/views
+- `list_fields(reader, path) -> Vec<FieldSchema>`:
+  - **Without catalogs**: `[schema, table]` → query `information_schema.columns`
+  - **With catalogs**: `[catalog, schema, table]` → query `information_schema.columns` with catalog filter
+- `contains_data(path) -> bool`: true when last element has `kind` == "table" or "view"
+- **SQL safety**: All interpolated identifiers use standard quote-escaping (`'` → `''`) via a shared `escape_sql_string()` helper
+- `get_icon(path) -> String`: return empty string (let Positron use defaults)
+- `preview_object(path)`: stub — return null (Data Explorer comm is a separate future feature)
+- `get_metadata(reader_uri, name) -> MetadataSchema`: return connection metadata
+
+**Dialect differences**: DuckDB's default schema is `main` (not `public`). Snowflake has a catalog→schema→table hierarchy. The `SqlDialect` trait gets new optional methods:
+- `has_catalogs() -> bool` — false by default, true for Snowflake
+- `schema_list_query() -> &str` — override for backends that don't support `information_schema.schemata`
+- `default_schema() -> &str` — `"main"` for DuckDB, `"public"` for PostgreSQL, etc.
+ +### Modify: `ggsql-jupyter/src/kernel.rs` + +**Add connection comm tracking**: +```rust +connection_comm_id: Option, +``` + +**Opening the comm** (kernel-initiated, sent on iopub after a successful `-- @connect:`): +```rust +// Send comm_open on iopub with target_name = "positron.connection" +self.send_iopub("comm_open", json!({ + "comm_id": new_uuid, + "target_name": "positron.connection", + "data": { + "name": display_name, // e.g. "DuckDB (memory)" or "Snowflake (myaccount)" + "language_id": "ggsql", + "host": host, // e.g. "memory" or "myaccount.snowflakecomputing.com" + "type": type_name, // e.g. "DuckDB" or "Snowflake" + "code": meta_command // e.g. "-- @connect: duckdb://memory" + } +}), parent).await?; +``` + +**Handle incoming comm_msg** for connection comm (JSON-RPC methods from Positron): +- Route `list_objects`, `list_fields`, `contains_data`, `get_icon`, `get_metadata` to `connection.rs` functions +- `preview_object`: stub — return null (full Data Explorer comm is a separate future feature) +- Send JSON-RPC responses back on shell via `send_shell_reply("comm_msg", ...)` + +**Handle comm_close**: clear `connection_comm_id` + +**Update comm_info_request**: include connection comm in response + +**Open connection comm on startup**: After kernel initializes with default DuckDB reader, automatically open a `positron.connection` comm so the Connections pane shows the default database immediately. +- Use `create_message(..., None)` for the no-parent startup comm_open (same pattern as `send_status_initial` at kernel.rs:749) +- Add `send_iopub_no_parent()` helper (or generalize `send_iopub` to accept `Option<&JupyterMessage>`) + +**Replacing connection comms on reader switch**: When `-- @connect:` switches readers: +1. If `connection_comm_id` is `Some`, send `comm_close` on iopub for the old comm ID first +2. Clear `connection_comm_id` +3. Open a new comm with a fresh UUID +4. 
This ensures no stale comm IDs linger and the Connections pane sees the old connection as disconnected + the new one as active + +### Modify: `ggsql-jupyter/src/executor.rs` + +**Make reader swappable**: +- Change `reader: DuckDBReader` → `reader: Box` +- Add `pub fn swap_reader(&mut self, new_reader: Box)` +- Add `pub fn reader(&self) -> &dyn Reader` accessor for connection.rs queries +- For visualization execution: call `self.reader.execute(query)` directly on the `Box` + +**Add meta-command handling**: +- In `execute()`, check if code starts with `-- @connect: ` +- Parse the connection URI from the meta-command +- Call `create_reader(uri)` (shared function) to build the new reader +- Swap the reader via `swap_reader()` +- Return a new `ExecutionResult::ConnectionChanged { uri, display_name }` variant + +**New `create_reader(uri)` function** (in executor.rs or a new `reader_factory.rs`): +- Parse connection string using `ggsql::reader::connection::parse_connection_string()` +- Match on `ConnectionInfo` variant to construct appropriate reader +- Feature-gated: DuckDB (default), SQLite (optional), ODBC (optional) + +### Modify: `ggsql-jupyter/src/main.rs` + +- Add `--reader` CLI arg (default: `"duckdb://memory"`) +- Pass the reader URI to `KernelServer::new(connection, reader_uri)` +- Kernel creates initial reader from this URI + +### Modify: `ggsql-jupyter/src/lib.rs` + +- Add `mod connection;` + +### Modify: `ggsql-jupyter/Cargo.toml` + +Current: `ggsql = { workspace = true, features = ["duckdb", "vegalite"] }` — only DuckDB. + +Add feature flags: +```toml +[features] +default = [] +sqlite = ["ggsql/sqlite"] +odbc = ["ggsql/odbc"] +all-readers = ["sqlite", "odbc"] +``` + +Update ggsql dep: `ggsql = { workspace = true, features = ["duckdb", "vegalite"] }` stays as default (DuckDB always available). 
+ +**`create_reader()` runtime error handling**: When `-- @connect:` requests a reader that isn't compiled in, return a clear error message to the user via execute_reply: +``` +Error: SQLite support is not compiled into this ggsql-jupyter binary. +Rebuild with: cargo build --features sqlite +``` +This uses `#[cfg(feature = "...")]` branches with a fallback error arm per reader type. + +--- + +## Part 3: VS Code Extension Connection Drivers (`ggsql-vscode/`) + +### New file: `ggsql-vscode/src/connections.ts` + +**`createConnectionDrivers(): positron.ConnectionsDriver[]`** + +Returns array of drivers to register. Each driver: +- `generateCode(inputs)` → returns `-- @connect: ` meta-command string +- `connect(code)` → calls `positron.runtime.executeCode('ggsql', code, false)` to send the meta-command to the running kernel + +**DuckDB driver** (`driverId: 'ggsql-duckdb'`): +- Inputs: `database` (string, optional — empty = in-memory) +- generateCode: `-- @connect: duckdb://memory` or `-- @connect: duckdb://` + +**Snowflake driver** (`driverId: 'ggsql-snowflake'`): +- Inputs: `account` (string, required), `warehouse` (string, required), `database` (string, optional), `schema` (string, optional) +- generateCode: builds full ODBC string e.g. `-- @connect: odbc://Driver=Snowflake;Server=.snowflakecomputing.com;Warehouse=` + +**Generic ODBC driver** (`driverId: 'ggsql-odbc'`): +- Inputs: `connection_string` (string, required — raw ODBC connection string) +- generateCode: `-- @connect: odbc://` + +### Modify: `ggsql-vscode/src/extension.ts` + +In `activate()`, after registering the runtime manager: +```typescript +import { createConnectionDrivers } from './connections'; +// ... 
+const drivers = createConnectionDrivers(); +for (const driver of drivers) { + context.subscriptions.push(positronApi.connections.registerConnectionDriver(driver)); +} +``` + +### Modify: `ggsql-vscode/src/manager.ts` + +- Update `createKernelSpec()` to accept optional `readerUri` parameter +- Pass `--reader ` in spawn args when `readerUri` is provided +- Add `getActiveSession()` method so `connect()` can check if a kernel is running + +--- + +## Implementation Order + +### Phase 1: Kernel meta-commands and dynamic reader switching +1. Modify `executor.rs` — make reader swappable, add meta-command detection, add `create_reader()` +2. Modify `main.rs` — add `--reader` CLI arg, pass to executor +3. Modify `kernel.rs` — handle `ConnectionChanged` result from executor +4. Test: start kernel with `--reader duckdb://memory`, verify meta-command works + +### Phase 2: Connection comm protocol +5. Create `connection.rs` — schema introspection via information_schema +6. Modify `kernel.rs` — open `positron.connection` comm on startup and after `-- @connect:`, handle incoming JSON-RPC methods +7. Test: start kernel in Positron, verify Connections pane shows DuckDB schema + +### Phase 3: ODBC reader +8. Create `src/reader/odbc.rs` — generic ODBC reader with cursor→DataFrame conversion +9. Add Workbench Snowflake credential detection +10. Modify `connection.rs`, `mod.rs`, `Cargo.toml` for ODBC feature +11. Test: connect to local ODBC data source, verify queries work + +### Phase 4: Extension connection drivers +12. Create `ggsql-vscode/src/connections.ts` — DuckDB, Snowflake, generic ODBC drivers +13. Modify `extension.ts` — register drivers on activation +14. Test: open New Connection dialog, create DuckDB connection, verify Connections pane updates + +### Phase 5: Integration & polish +15. End-to-end test: New Connection dialog → kernel connection → Connections pane browsing +16. 
Handle edge cases: connection failures, reader not compiled in, comm lifecycle + +--- + +## Verification + +1. **Unit tests**: Meta-command parsing, ODBC URI parsing, Workbench credential detection, schema introspection queries +2. **Integration test**: Start kernel with `--reader duckdb://memory`, execute `-- @connect: duckdb://memory`, verify comm_open message on iopub +3. **Manual Positron test**: Open ggsql session → Connections pane shows DuckDB → expand to see schemas/tables/columns → New Connection dialog → create Snowflake connection → Connections pane updates +4. **Existing tests**: Run `cargo test` to ensure no regressions in parser/reader/writer + +## Key files to modify + +| File | Change | +|------|--------| +| `src/execute/mod.rs` | Change `prepare_data_with_reader` to `&dyn Reader` | +| `src/reader/mod.rs` | Remove `Self: Sized` from `execute()`, add odbc module | +| `src/reader/odbc.rs` | **NEW** — Generic ODBC reader | +| `src/reader/connection.rs` | Add ODBC variant | +| `src/Cargo.toml` | Add odbc feature + deps | +| `ggsql-jupyter/src/connection.rs` | **NEW** — Schema introspection | +| `ggsql-jupyter/src/kernel.rs` | Connection comm protocol | +| `ggsql-jupyter/src/executor.rs` | Dynamic reader switching, meta-commands | +| `ggsql-jupyter/src/main.rs` | `--reader` CLI arg | +| `ggsql-jupyter/src/lib.rs` | Add connection module | +| `ggsql-jupyter/Cargo.toml` | Add odbc feature | +| `ggsql-vscode/src/connections.ts` | **NEW** — Connection drivers | +| `ggsql-vscode/src/extension.ts` | Register connection drivers | +| `ggsql-vscode/src/manager.ts` | Pass reader URI to kernel | diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml index 5b31ef8f..6f4bb1f0 100644 --- a/ggsql-jupyter/Cargo.toml +++ b/ggsql-jupyter/Cargo.toml @@ -56,6 +56,10 @@ hex = "0.4" # UUID for message IDs uuid = { version = "1.0", features = ["v4"] } +[features] +default = [] +odbc = ["ggsql/odbc"] + [dev-dependencies] # Test utilities tokio-test = "0.4" diff --git 
a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs new file mode 100644 index 00000000..a1b3085d --- /dev/null +++ b/ggsql-jupyter/src/connection.rs @@ -0,0 +1,209 @@ +//! Database schema introspection for the Positron Connections pane. +//! +//! Delegates introspection SQL to the reader's `SqlDialect`, which provides +//! backend-specific queries (e.g. `information_schema` for DuckDB/PostgreSQL, +//! `sqlite_master` / `PRAGMA` for SQLite). + +use ggsql::reader::Reader; +use serde::Serialize; +use serde_json::Value; + +/// An object in the schema hierarchy (catalog, schema, table, or view). +#[derive(Debug, Serialize)] +pub struct ObjectSchema { + pub name: String, + pub kind: String, +} + +/// A field (column) in a table. +#[derive(Debug, Serialize)] +pub struct FieldSchema { + pub name: String, + pub dtype: String, +} + +/// List objects at the given path depth. +/// +/// Path semantics (catalog → schema → table): +/// - `[]` → list catalogs +/// - `[catalog]` → list schemas in that catalog +/// - `[catalog, schema]` → list tables and views +pub fn list_objects(reader: &dyn Reader, path: &[String]) -> Result, String> { + match path.len() { + 0 => list_catalogs(reader), + 1 => list_schemas(reader, &path[0]), + 2 => list_tables(reader, &path[0], &path[1]), + _ => Ok(vec![]), + } +} + +/// List fields (columns) for the object at the given path. +/// +/// - `[catalog, schema, table]` → list columns +pub fn list_fields(reader: &dyn Reader, path: &[String]) -> Result, String> { + if path.len() == 3 { + list_columns(reader, &path[0], &path[1], &path[2]) + } else { + Ok(vec![]) + } +} + +/// Whether the path points to an object that contains data (table or view). 
+pub fn contains_data(path: &[Value]) -> bool {
+    path.last()
+        .and_then(|v| v.get("kind"))
+        .and_then(|k| k.as_str())
+        .map(|k| k == "table" || k == "view")
+        .unwrap_or(false)
+}
+
+fn list_catalogs(reader: &dyn Reader) -> Result<Vec<ObjectSchema>, String> {
+    let sql = reader.dialect().sql_list_catalogs();
+    let df = reader
+        .execute_sql(&sql)
+        .map_err(|e| format!("Failed to list catalogs: {}", e))?;
+
+    let col = df
+        .column("catalog_name")
+        .map_err(|e| format!("Missing catalog_name column: {}", e))?;
+
+    let mut catalogs = Vec::new();
+    for i in 0..df.height() {
+        if let Ok(val) = col.get(i) {
+            let name = val.to_string().trim_matches('"').to_string();
+            catalogs.push(ObjectSchema {
+                name,
+                kind: "catalog".to_string(),
+            });
+        }
+    }
+    Ok(catalogs)
+}
+
+fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result<Vec<ObjectSchema>, String> {
+    let sql = reader.dialect().sql_list_schemas(catalog);
+    let df = reader
+        .execute_sql(&sql)
+        .map_err(|e| format!("Failed to list schemas: {}", e))?;
+
+    let col = df
+        .column("schema_name")
+        .map_err(|e| format!("Missing schema_name column: {}", e))?;
+
+    let mut schemas = Vec::new();
+    for i in 0..df.height() {
+        if let Ok(val) = col.get(i) {
+            let name = val.to_string().trim_matches('"').to_string();
+            schemas.push(ObjectSchema {
+                name,
+                kind: "schema".to_string(),
+            });
+        }
+    }
+    Ok(schemas)
+}
+
+fn list_tables(
+    reader: &dyn Reader,
+    catalog: &str,
+    schema: &str,
+) -> Result<Vec<ObjectSchema>, String> {
+    let sql = reader.dialect().sql_list_tables(catalog, schema);
+    let df = reader
+        .execute_sql(&sql)
+        .map_err(|e| format!("Failed to list tables: {}", e))?;
+
+    let name_col = df
+        .column("table_name")
+        .map_err(|e| format!("Missing table_name column: {}", e))?;
+    let type_col = df
+        .column("table_type")
+        .map_err(|e| format!("Missing table_type column: {}", e))?;
+
+    let mut objects = Vec::new();
+    for i in 0..df.height() {
+        if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
+            let name =
name_val.to_string().trim_matches('"').to_string();
+            let table_type = type_val.to_string().trim_matches('"').to_uppercase();
+            let kind = if table_type.contains("VIEW") {
+                "view"
+            } else {
+                "table"
+            };
+            objects.push(ObjectSchema {
+                name,
+                kind: kind.to_string(),
+            });
+        }
+    }
+    Ok(objects)
+}
+
+fn list_columns(
+    reader: &dyn Reader,
+    catalog: &str,
+    schema: &str,
+    table: &str,
+) -> Result<Vec<FieldSchema>, String> {
+    let sql = reader.dialect().sql_list_columns(catalog, schema, table);
+    let df = reader
+        .execute_sql(&sql)
+        .map_err(|e| format!("Failed to list columns: {}", e))?;
+
+    let name_col = df
+        .column("column_name")
+        .map_err(|e| format!("Missing column_name column: {}", e))?;
+    let type_col = df
+        .column("data_type")
+        .map_err(|e| format!("Missing data_type column: {}", e))?;
+
+    let mut fields = Vec::new();
+    for i in 0..df.height() {
+        if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
+            let name = name_val.to_string().trim_matches('"').to_string();
+            let dtype = type_val.to_string().trim_matches('"').to_string();
+            fields.push(FieldSchema { name, dtype });
+        }
+    }
+    Ok(fields)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_contains_data_table() {
+        let path = vec![
+            serde_json::json!({"name": "memory", "kind": "catalog"}),
+            serde_json::json!({"name": "main", "kind": "schema"}),
+            serde_json::json!({"name": "users", "kind": "table"}),
+        ];
+        assert!(contains_data(&path));
+    }
+
+    #[test]
+    fn test_contains_data_schema() {
+        let path = vec![
+            serde_json::json!({"name": "memory", "kind": "catalog"}),
+            serde_json::json!({"name": "main", "kind": "schema"}),
+        ];
+        assert!(!contains_data(&path));
+    }
+
+    #[test]
+    fn test_contains_data_catalog() {
+        let path = vec![serde_json::json!({"name": "memory", "kind": "catalog"})];
+        assert!(!contains_data(&path));
+    }
+
+    #[test]
+    fn test_contains_data_view() {
+        let path = vec![
+            serde_json::json!({"name": "memory", "kind": "catalog"}),
serde_json::json!({"name": "main", "kind": "schema"}), + serde_json::json!({"name": "my_view", "kind": "view"}), + ]; + assert!(contains_data(&path)); + } +} diff --git a/ggsql-jupyter/src/display.rs b/ggsql-jupyter/src/display.rs index 6406441b..663d5b79 100644 --- a/ggsql-jupyter/src/display.rs +++ b/ggsql-jupyter/src/display.rs @@ -21,9 +21,24 @@ pub fn format_display_data(result: ExecutionResult) -> Value { match result { ExecutionResult::Visualization { spec } => format_vegalite(spec), ExecutionResult::DataFrame(df) => format_dataframe(df), + ExecutionResult::ConnectionChanged { display_name, .. } => { + format_connection_changed(&display_name) + } } } +/// Format a connection-changed message +fn format_connection_changed(display_name: &str) -> Value { + let text = format!("Connected to {}", display_name); + json!({ + "data": { + "text/plain": text + }, + "metadata": {}, + "transient": {} + }) +} + /// Format Vega-Lite visualization as display_data fn format_vegalite(spec: String) -> Value { let spec_value: Value = serde_json::from_str(&spec).unwrap_or_else(|e| { diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs index 1548f5e3..9eb3448a 100644 --- a/ggsql-jupyter/src/executor.rs +++ b/ggsql-jupyter/src/executor.rs @@ -2,10 +2,11 @@ //! //! This module handles the execution of ggsql queries using the existing //! ggsql library components (parser, DuckDB reader, Vega-Lite writer). +//! It supports dynamic reader switching via `-- @connect:` meta-commands. use anyhow::Result; use ggsql::{ - reader::{DuckDBReader, Reader}, + reader::{connection::parse_connection_string, DuckDBReader, Reader}, validate::validate, writer::{VegaLiteWriter, Writer}, }; @@ -20,39 +21,186 @@ pub enum ExecutionResult { Visualization { spec: String, // Vega-Lite JSON }, + /// Connection changed via meta-command + ConnectionChanged { + uri: String, + display_name: String, + }, +} + +/// Create a reader from a connection URI string. 
+///
+/// Supported schemes:
+/// - `duckdb://memory` or `duckdb://<path>` (always available)
+/// - `odbc://...` (requires `odbc` feature)
+pub fn create_reader(uri: &str) -> Result<Box<dyn Reader>> {
+    use ggsql::reader::connection::ConnectionInfo;
+
+    let info = parse_connection_string(uri)?;
+    match info {
+        ConnectionInfo::DuckDBMemory => {
+            let reader = DuckDBReader::from_connection_string("duckdb://memory")?;
+            Ok(Box::new(reader))
+        }
+        ConnectionInfo::DuckDBFile(path) => {
+            let reader =
+                DuckDBReader::from_connection_string(&format!("duckdb://{}", path))?;
+            Ok(Box::new(reader))
+        }
+        #[cfg(feature = "odbc")]
+        ConnectionInfo::ODBC(conn_str) => {
+            let reader = ggsql::reader::OdbcReader::from_connection_string(
+                &format!("odbc://{}", conn_str),
+            )?;
+            Ok(Box::new(reader))
+        }
+        _ => anyhow::bail!(
+            "Unsupported reader type for connection string: {}. \
+             Only DuckDB connections are currently supported in ggsql-jupyter.",
+            uri
+        ),
+    }
+}
+
+/// Generate a human-readable display name for a connection URI.
+pub fn display_name_for_uri(uri: &str) -> String {
+    if uri == "duckdb://memory" {
+        return "DuckDB (memory)".to_string();
+    }
+    if let Some(path) = uri.strip_prefix("duckdb://") {
+        return format!("DuckDB ({})", path);
+    }
+    if let Some(odbc) = uri.strip_prefix("odbc://") {
+        // Try to extract driver name from ODBC string
+        if let Some(driver_start) = odbc.to_lowercase().find("driver=") {
+            let rest = &odbc[driver_start + 7..];
+            let driver = rest
+                .split(';')
+                .next()
+                .unwrap_or("ODBC")
+                .trim_matches(|c| c == '{' || c == '}');
+            return format!("{} (ODBC)", driver);
+        }
+        return "ODBC".to_string();
+    }
+    uri.to_string()
+}
+
+/// Detect the database type name from a connection URI (e.g. "DuckDB", "Snowflake").
+pub fn type_name_for_uri(uri: &str) -> String { + if uri.starts_with("duckdb://") { + return "DuckDB".to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + if odbc.to_lowercase().contains("driver=snowflake") { + return "Snowflake".to_string(); + } + if odbc.to_lowercase().contains("driver={postgresql}") || odbc.to_lowercase().contains("driver=postgresql") { + return "PostgreSQL".to_string(); + } + return "ODBC".to_string(); + } + "Unknown".to_string() +} + +/// Extract the host portion from a connection URI. +pub fn host_for_uri(uri: &str) -> String { + if uri == "duckdb://memory" { + return "memory".to_string(); + } + if let Some(path) = uri.strip_prefix("duckdb://") { + return path.to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + // Try to extract server + if let Some(server_start) = odbc.to_lowercase().find("server=") { + let rest = &odbc[server_start + 7..]; + if let Some(host) = rest.split(';').next() { + return host.to_string(); + } + } + } + uri.to_string() +} + +/// The `-- @connect:` meta-command prefix. +const META_CONNECT_PREFIX: &str = "-- @connect:"; + +/// Parse a `-- @connect: ` meta-command, returning the URI if present. 
+pub fn parse_meta_command(code: &str) -> Option<String> {
+    let trimmed = code.trim();
+    if let Some(rest) = trimmed.strip_prefix(META_CONNECT_PREFIX) {
+        Some(rest.trim().to_string())
+    } else {
+        None
+    }
 }
 
-/// Query executor maintaining persistent DuckDB connection
+/// Query executor maintaining persistent database connection
 pub struct QueryExecutor {
-    reader: DuckDBReader,
+    reader: Box<dyn Reader>,
     writer: VegaLiteWriter,
+    reader_uri: String,
 }
 
 impl QueryExecutor {
-    /// Create a new query executor with in-memory DuckDB database
-    pub fn new() -> Result<Self> {
-        tracing::info!("Initializing query executor with in-memory DuckDB");
-        let reader = DuckDBReader::from_connection_string("duckdb://memory")?;
+    /// Create a new query executor with a given connection URI
+    pub fn new_with_uri(uri: &str) -> Result<Self> {
+        tracing::info!("Initializing query executor with reader: {}", uri);
+        let reader = create_reader(uri)?;
         let writer = VegaLiteWriter::new();
-        Ok(Self { reader, writer })
+        Ok(Self {
+            reader,
+            writer,
+            reader_uri: uri.to_string(),
+        })
     }
 
-    /// Execute a ggsql query
-    ///
-    /// This handles both pure SQL queries and queries with VISUALISE clauses.
- /// - /// # Arguments - /// - /// * `code` - The ggsql query to execute - /// - /// # Returns + /// Create a new query executor with the default in-memory DuckDB database + #[cfg(test)] + pub fn new() -> Result { + Self::new_with_uri("duckdb://memory") + } + + /// Get the current reader URI + pub fn reader_uri(&self) -> &str { + &self.reader_uri + } + + /// Get a reference to the current reader (for schema introspection) + pub fn reader(&self) -> &dyn Reader { + &*self.reader + } + + /// Swap the reader to a new connection, returning the old URI + pub fn swap_reader(&mut self, uri: &str) -> Result { + let new_reader = create_reader(uri)?; + self.reader = new_reader; + let old_uri = std::mem::replace(&mut self.reader_uri, uri.to_string()); + Ok(old_uri) + } + + /// Execute a ggsql query or meta-command /// - /// An ExecutionResult containing either a DataFrame (for pure SQL) or - /// a Visualization (for queries with VISUALISE clause) - pub fn execute(&self, code: &str) -> Result { + /// This handles: + /// - `-- @connect: ` meta-commands for switching readers + /// - Pure SQL queries (no VISUALISE) + /// - ggsql queries with VISUALISE clauses + pub fn execute(&mut self, code: &str) -> Result { tracing::debug!("Executing query: {} chars", code.len()); + // Check for meta-commands first + if let Some(uri) = parse_meta_command(code) { + tracing::info!("Meta-command: switching reader to {}", uri); + self.swap_reader(&uri)?; + let display_name = display_name_for_uri(&uri); + return Ok(ExecutionResult::ConnectionChanged { + uri, + display_name, + }); + } + // 1. 
Validate to check if there's a visualization let validated = validate(code)?; @@ -93,7 +241,7 @@ mod tests { #[test] fn test_simple_visualization() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y VISUALISE x, y DRAW point"; let result = executor.execute(code).unwrap(); @@ -102,7 +250,7 @@ mod tests { #[test] fn test_pure_sql() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y"; let result = executor.execute(code).unwrap(); @@ -111,10 +259,38 @@ mod tests { #[test] fn test_error_handling() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT * FROM nonexistent_table"; let result = executor.execute(code); assert!(result.is_err()); } + + #[test] + fn test_parse_meta_command() { + assert_eq!( + parse_meta_command("-- @connect: duckdb://memory"), + Some("duckdb://memory".to_string()) + ); + assert_eq!( + parse_meta_command(" -- @connect: duckdb://my.db "), + Some("duckdb://my.db".to_string()) + ); + assert_eq!(parse_meta_command("SELECT 1"), None); + } + + #[test] + fn test_meta_command_switches_reader() { + let mut executor = QueryExecutor::new().unwrap(); + assert_eq!(executor.reader_uri(), "duckdb://memory"); + + let result = executor.execute("-- @connect: duckdb://memory").unwrap(); + assert!(matches!(result, ExecutionResult::ConnectionChanged { .. })); + } + + #[test] + fn test_display_name_for_uri() { + assert_eq!(display_name_for_uri("duckdb://memory"), "DuckDB (memory)"); + assert_eq!(display_name_for_uri("duckdb://my.db"), "DuckDB (my.db)"); + } } diff --git a/ggsql-jupyter/src/kernel.rs b/ggsql-jupyter/src/kernel.rs index 659d95ca..8eda7723 100644 --- a/ggsql-jupyter/src/kernel.rs +++ b/ggsql-jupyter/src/kernel.rs @@ -3,8 +3,9 @@ //! This module implements the Jupyter messaging protocol over ZeroMQ sockets, //! 
handling kernel_info, execute, and shutdown requests. +use crate::connection; use crate::display::format_display_data; -use crate::executor::QueryExecutor; +use crate::executor::{self, ExecutionResult, QueryExecutor}; use crate::message::{ConnectionInfo, JupyterMessage, MessageHeader}; use anyhow::Result; use hmac::{Hmac, Mac}; @@ -32,11 +33,12 @@ pub struct KernelServer { variables_comm_id: Option, ui_comm_id: Option, plot_comm_id: Option, + connection_comm_id: Option, } impl KernelServer { /// Create a new kernel server from connection info - pub async fn new(connection: ConnectionInfo) -> Result { + pub async fn new(connection: ConnectionInfo, reader_uri: &str) -> Result { tracing::info!("Initializing kernel server"); // Initialize sockets @@ -68,8 +70,8 @@ impl KernelServer { tracing::info!("Binding heartbeat socket to {}", hb_addr); heartbeat.bind(&hb_addr).await?; - // Create executor - let executor = QueryExecutor::new()?; + // Create executor with the specified reader + let executor = QueryExecutor::new_with_uri(reader_uri)?; // Generate session ID let session = uuid::Uuid::new_v4().to_string(); @@ -92,12 +94,16 @@ impl KernelServer { variables_comm_id: None, ui_comm_id: None, plot_comm_id: None, + connection_comm_id: None, }; // Send initial "starting" status on IOPub // This is required by Jupyter protocol - exactly once at process startup kernel.send_status_initial("starting").await?; + // Open initial connection comm so the Connections pane shows the database + kernel.open_connection_comm(reader_uri).await?; + Ok(kernel) } @@ -310,9 +316,15 @@ impl KernelServer { match result { Ok(exec_result) => { + // If the connection changed, open a new connection comm + let is_connection_changed = matches!(&exec_result, ExecutionResult::ConnectionChanged { .. }); + if let ExecutionResult::ConnectionChanged { ref uri, .. 
} = &exec_result { + self.open_connection_comm(uri).await?; + } + // Send execute_result (not display_data) // Per Jupyter spec: execute_result includes execution_count - if !silent { + if !silent && !is_connection_changed { let display_data = format_display_data(exec_result); // Build message content, including output_location if present @@ -593,6 +605,11 @@ impl KernelServer { else if Some(comm_id.to_string()) == self.plot_comm_id { tracing::info!("Received plot request: {} (ignoring)", method); } + // Handle positron.connection requests + else if Some(comm_id.to_string()) == self.connection_comm_id { + self.handle_connection_rpc(method, rpc_id, comm_id, parent, identities) + .await?; + } // Unknown comm else { tracing::warn!("Message for unknown comm_id: {}", comm_id); @@ -633,6 +650,11 @@ impl KernelServer { comms[id] = json!({"target_name": "positron.plot"}); } } + if let Some(id) = &self.connection_comm_id { + if target_name.is_none() || target_name == Some("positron.connection") { + comms[id] = json!({"target_name": "positron.connection"}); + } + } tracing::info!( "Returning comms: {}", @@ -676,6 +698,9 @@ impl KernelServer { } else if Some(comm_id.to_string()) == self.plot_comm_id { tracing::info!("Closing positron.plot comm"); self.plot_comm_id = None; + } else if Some(comm_id.to_string()) == self.connection_comm_id { + tracing::info!("Closing positron.connection comm"); + self.connection_comm_id = None; } else { tracing::warn!("Close for unknown comm_id: {}", comm_id); } @@ -684,6 +709,157 @@ impl KernelServer { Ok(()) } + /// Open (or replace) a `positron.connection` comm for the current reader. + /// + /// The kernel initiates this comm (backend-initiated). If an existing + /// connection comm is open, it is closed first. 
+ async fn open_connection_comm(&mut self, uri: &str) -> Result<()> { + // Close existing connection comm if any + if let Some(old_id) = self.connection_comm_id.take() { + tracing::info!("Closing old connection comm: {}", old_id); + let close_msg = self.create_message( + "comm_close", + json!({ "comm_id": old_id }), + None, + ); + let zmq_msg = self.serialize_message_with_topic(&close_msg, "comm_close")?; + self.iopub.send(zmq_msg).await?; + } + + let comm_id = uuid::Uuid::new_v4().to_string(); + let display_name = executor::display_name_for_uri(uri); + let type_name = executor::type_name_for_uri(uri); + let host = executor::host_for_uri(uri); + let meta_command = format!("-- @connect: {}", uri); + + tracing::info!( + "Opening positron.connection comm: {} ({})", + comm_id, + display_name + ); + + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": comm_id, + "target_name": "positron.connection", + "data": { + "name": display_name, + "language_id": "ggsql", + "host": host, + "type": type_name, + "code": meta_command + } + }), + None, + ); + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + self.connection_comm_id = Some(comm_id); + Ok(()) + } + + /// Handle JSON-RPC requests on the connection comm + async fn handle_connection_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Connection RPC: {}", method); + + let params = &parent.content["data"]["params"]; + + let result = match method { + "list_objects" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_objects(self.executor.reader(), &path) { + Ok(objects) => json!(objects), + Err(e) => { + tracing::error!("list_objects failed: {}", e); + json!([]) + 
} + } + } + "list_fields" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_fields(self.executor.reader(), &path) { + Ok(fields) => json!(fields), + Err(e) => { + tracing::error!("list_fields failed: {}", e); + json!([]) + } + } + } + "contains_data" => { + let path: Vec = params["path"] + .as_array() + .cloned() + .unwrap_or_default(); + let has_data = connection::contains_data(&path); + json!(has_data) + } + "get_icon" => json!(""), + "preview_object" => json!(null), + "get_metadata" => { + let uri = self.executor.reader_uri(); + json!({ + "name": executor::display_name_for_uri(uri), + "language_id": "ggsql", + "host": executor::host_for_uri(uri), + "type": executor::type_name_for_uri(uri), + "code": format!("-- @connect: {}", uri) + }) + } + _ => { + tracing::warn!("Unknown connection method: {}", method); + json!(null) + } + }; + + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + Ok(()) + } + /// Send a message on the IOPub channel async fn send_iopub( &mut self, diff --git a/ggsql-jupyter/src/lib.rs b/ggsql-jupyter/src/lib.rs index 40861748..426f767d 100644 --- a/ggsql-jupyter/src/lib.rs +++ b/ggsql-jupyter/src/lib.rs @@ -2,6 +2,7 @@ //! //! This module exposes the internal components for testing. +pub mod connection; pub mod display; pub mod executor; pub mod message; diff --git a/ggsql-jupyter/src/main.rs b/ggsql-jupyter/src/main.rs index 316ab8ba..9417e791 100644 --- a/ggsql-jupyter/src/main.rs +++ b/ggsql-jupyter/src/main.rs @@ -2,6 +2,7 @@ //! //! A Jupyter kernel for executing ggsql queries with rich Vega-Lite visualizations. 
+mod connection; mod display; mod executor; mod kernel; @@ -22,6 +23,10 @@ struct Args { #[arg(short = 'f', long = "connection-file")] connection_file: Option, + /// Database connection URI (e.g. "duckdb://memory") + #[arg(long, default_value = "duckdb://memory")] + reader: String, + /// Install the kernel spec #[arg(long)] install: bool, @@ -69,7 +74,7 @@ async fn main() -> Result<()> { tracing::info!("Creating kernel server"); // Create and run kernel - let mut kernel = kernel::KernelServer::new(connection).await?; + let mut kernel = kernel::KernelServer::new(connection, &args.reader).await?; tracing::info!("Kernel ready, starting event loop"); diff --git a/ggsql-python/src/lib.rs b/ggsql-python/src/lib.rs index d477275e..6e5c543c 100644 --- a/ggsql-python/src/lib.rs +++ b/ggsql-python/src/lib.rs @@ -164,6 +164,10 @@ impl Reader for PyReaderBridge { }) } + fn execute(&self, query: &str) -> ggsql::Result { + ggsql::reader::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn ggsql::reader::SqlDialect { &ANSI_DIALECT } diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts new file mode 100644 index 00000000..9c2b7e11 --- /dev/null +++ b/ggsql-vscode/src/connections.ts @@ -0,0 +1,151 @@ +/* + * Connection Drivers for Positron's Connections pane + * + * Registers drivers that let users create database connections via the + * "New Connection" dialog. Each driver generates a `-- @connect:` meta-command + * that the ggsql-jupyter kernel interprets to switch readers. + */ + +import type * as positron from '@posit-dev/positron'; + +type PositronApi = positron.PositronApi; + +/** + * Create the set of ggsql connection drivers to register with Positron. + */ +export function createConnectionDrivers( + positronApi: PositronApi +): positron.ConnectionsDriver[] { + return [ + createDuckDBDriver(positronApi), + createSnowflakeDriver(positronApi), + createOdbcDriver(positronApi), + ]; +} + +/** + * DuckDB connection driver. 
+ * + * Inputs: optional database file path (empty = in-memory). + */ +function createDuckDBDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-duckdb', + metadata: { + languageId: 'ggsql', + name: 'DuckDB', + inputs: [ + { + id: 'database', + label: 'Database', + type: 'string', + value: '', + }, + ], + }, + generateCode: (inputs) => { + const db = inputs.find((i) => i.id === 'database')?.value?.trim(); + if (!db) { + return '-- @connect: duckdb://memory'; + } + return `-- @connect: duckdb://${db}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} + +/** + * Snowflake connection driver (via ODBC). + * + * Builds an ODBC connection string targeting the Snowflake ODBC driver. + * Workbench OAuth token injection happens automatically in the kernel. + */ +function createSnowflakeDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + inputs: [ + { + id: 'account', + label: 'Account', + type: 'string', + }, + { + id: 'warehouse', + label: 'Warehouse', + type: 'string', + }, + { + id: 'database', + label: 'Database', + type: 'string', + value: '', + }, + { + id: 'schema', + label: 'Schema', + type: 'string', + value: '', + }, + ], + }, + generateCode: (inputs) => { + const account = inputs.find((i) => i.id === 'account')?.value ?? ''; + const warehouse = inputs.find((i) => i.id === 'warehouse')?.value ?? ''; + const database = inputs.find((i) => i.id === 'database')?.value ?? ''; + const schema = inputs.find((i) => i.id === 'schema')?.value ?? 
''; + + let connStr = `Driver=Snowflake;Server=${account}.snowflakecomputing.com;Warehouse=${warehouse}`; + if (database) { + connStr += `;Database=${database}`; + } + if (schema) { + connStr += `;Schema=${schema}`; + } + return `-- @connect: odbc://${connStr}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} + +/** + * Generic ODBC connection driver. + * + * Lets users paste a raw ODBC connection string. + */ +function createOdbcDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-odbc', + metadata: { + languageId: 'ggsql', + name: 'ODBC', + inputs: [ + { + id: 'connection_string', + label: 'Connection String', + type: 'string', + }, + ], + }, + generateCode: (inputs) => { + const connStr = + inputs.find((i) => i.id === 'connection_string')?.value ?? ''; + return `-- @connect: odbc://${connStr}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} diff --git a/ggsql-vscode/src/extension.ts b/ggsql-vscode/src/extension.ts index 6d6d1ec5..54b76edf 100644 --- a/ggsql-vscode/src/extension.ts +++ b/ggsql-vscode/src/extension.ts @@ -8,6 +8,7 @@ import * as vscode from 'vscode'; import { tryAcquirePositronApi } from '@posit-dev/positron'; import { GgsqlRuntimeManager } from './manager'; +import { createConnectionDrivers } from './connections'; // Output channel for logging const outputChannel = vscode.window.createOutputChannel('ggsql'); @@ -42,6 +43,15 @@ export function activate(context: vscode.ExtensionContext): void { context.subscriptions.push(disposable); log('ggsql runtime manager registered successfully'); + + // Register connection drivers for the Connections pane + const drivers = createConnectionDrivers(positronApi); + for (const driver of drivers) { + const driverDisposable = positronApi.connections.registerConnectionDriver(driver); + context.subscriptions.push(driverDisposable); + 
} + + log(`Registered ${drivers.length} connection drivers`); } /** diff --git a/ggsql-vscode/src/manager.ts b/ggsql-vscode/src/manager.ts index c67a8924..863c4bcf 100644 --- a/ggsql-vscode/src/manager.ts +++ b/ggsql-vscode/src/manager.ts @@ -118,7 +118,7 @@ function generateMetadata( * * @param workspacePath - Optional workspace path to use as the kernel's working directory */ -function createKernelSpec(workspacePath?: string): JupyterKernelSpec { +function createKernelSpec(workspacePath?: string, readerUri?: string): JupyterKernelSpec { const kernelPath = getKernelPath(); return { @@ -132,11 +132,20 @@ function createKernelSpec(workspacePath?: string): JupyterKernelSpec { startKernel: async (session: JupyterSession, kernel: JupyterKernel) => { kernel.log(`Starting ggsql kernel with connection file: ${session.state.connectionFile}`); kernel.log(`Working directory: ${workspacePath ?? 'inherited from parent'}`); + if (readerUri) { + kernel.log(`Reader URI: ${readerUri}`); + } const connectionFile = session.state.connectionFile; + // Build arguments + const args = ['-f', connectionFile]; + if (readerUri) { + args.push('--reader', readerUri); + } + // Start the kernel process - const proc = cp.spawn(kernelPath, ['-f', connectionFile], { + const proc = cp.spawn(kernelPath, args, { stdio: ['ignore', 'pipe', 'pipe'], detached: false, cwd: workspacePath diff --git a/src/Cargo.toml b/src/Cargo.toml index f2c85bc6..64ad22f1 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -39,6 +39,8 @@ duckdb = { workspace = true, optional = true } arrow = { workspace = true, optional = true } postgres = { workspace = true, optional = true } rusqlite = { workspace = true, optional = true } +odbc-api = { workspace = true, optional = true } +toml_edit = { workspace = true, optional = true } # Writers plotters = { workspace = true, optional = true } @@ -83,12 +85,13 @@ polars-sql = ["polars/sql"] parquet = ["polars/parquet"] postgres = ["dep:postgres"] sqlite = ["dep:rusqlite"] +odbc = 
["dep:odbc-api", "dep:toml_edit"] vegalite = [] ggplot2 = [] builtin-data = [] python = ["dep:pyo3"] rest-api = ["dep:axum", "dep:tokio", "dep:tower-http", "dep:tracing", "dep:tracing-subscriber", "duckdb", "vegalite"] -all-readers = ["duckdb", "postgres", "sqlite", "polars-sql"] +all-readers = ["duckdb", "postgres", "sqlite", "polars-sql", "odbc"] all-writers = ["vegalite", "ggplot2", "plotters"] # cargo-packager configuration for cross-platform installers diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 625a17b4..b32d4e76 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -888,7 +888,7 @@ pub struct PreparedData { /// # Arguments /// * `query` - The full ggsql query string /// * `reader` - A Reader implementation for executing SQL -pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result { +pub fn prepare_data_with_reader(query: &str, reader: &dyn Reader) -> Result { let execute_query = |sql: &str| reader.execute_sql(sql); let dialect = reader.dialect(); diff --git a/src/reader/connection.rs b/src/reader/connection.rs index e592d5cc..8088403e 100644 --- a/src/reader/connection.rs +++ b/src/reader/connection.rs @@ -19,6 +19,9 @@ pub enum ConnectionInfo { /// SQLite file-based database #[allow(dead_code)] SQLite(String), + /// Generic ODBC connection (raw connection string after `odbc://` prefix) + #[allow(dead_code)] + ODBC(String), } /// Parse a connection string into connection information @@ -84,8 +87,17 @@ pub fn parse_connection_string(uri: &str) -> Result { return Ok(ConnectionInfo::SQLite(cleaned_path.to_string())); } + if let Some(conn_str) = uri.strip_prefix("odbc://") { + if conn_str.is_empty() { + return Err(GgsqlError::ReaderError( + "ODBC connection string cannot be empty".to_string(), + )); + } + return Ok(ConnectionInfo::ODBC(conn_str.to_string())); + } + Err(GgsqlError::ReaderError(format!( - "Unsupported connection string format: {}. 
Supported: duckdb://, polars://, postgres://, sqlite://", + "Unsupported connection string format: {}. Supported: duckdb://, polars://, postgres://, sqlite://, odbc://", uri ))) } @@ -169,6 +181,25 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_odbc() { + let info = + parse_connection_string("odbc://Driver=Snowflake;Server=myaccount.snowflakecomputing.com") + .unwrap(); + assert_eq!( + info, + ConnectionInfo::ODBC( + "Driver=Snowflake;Server=myaccount.snowflakecomputing.com".to_string() + ) + ); + } + + #[test] + fn test_odbc_empty() { + let result = parse_connection_string("odbc://"); + assert!(result.is_err()); + } + #[test] fn test_unsupported_scheme() { let result = parse_connection_string("mysql://localhost/db"); diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 5f3e58ec..62c5c3e2 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -639,6 +639,10 @@ impl Reader for DuckDBReader { Ok(()) } + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn super::SqlDialect { &DuckDbDialect } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 214f08b6..e64c7bc4 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -94,6 +94,49 @@ pub trait SqlDialect { } } + // ========================================================================= + // Schema introspection queries (for Connections pane) + // ========================================================================= + + /// SQL to list catalog names. Returns rows with column `catalog_name`. + fn sql_list_catalogs(&self) -> String { + "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name" + .into() + } + + /// SQL to list schema names within a catalog. Returns rows with column `schema_name`. 
+ fn sql_list_schemas(&self, catalog: &str) -> String { + format!( + "SELECT DISTINCT schema_name FROM information_schema.schemata \ + WHERE catalog_name = '{}' ORDER BY schema_name", + catalog.replace('\'', "''") + ) + } + + /// SQL to list tables/views within a catalog and schema. + /// Returns rows with columns `table_name` and `table_type`. + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + format!( + "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ + WHERE table_catalog = '{}' AND table_schema = '{}' ORDER BY table_name", + catalog.replace('\'', "''"), + schema.replace('\'', "''") + ) + } + + /// SQL to list columns in a table. + /// Returns rows with columns `column_name` and `data_type`. + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + format!( + "SELECT column_name, data_type FROM information_schema.columns \ + WHERE table_catalog = '{}' AND table_schema = '{}' AND table_name = '{}' \ + ORDER BY ordinal_position", + catalog.replace('\'', "''"), + schema.replace('\'', "''"), + table.replace('\'', "''") + ) + } + /// Scalar MAX across any number of SQL expressions. 
fn sql_greatest(&self, exprs: &[&str]) -> String { let mut result = exprs[0].to_string(); @@ -177,6 +220,9 @@ pub mod polars_sql; #[cfg(feature = "sqlite")] pub mod sqlite; +#[cfg(feature = "odbc")] +pub mod odbc; + pub mod connection; pub mod data; mod spec; @@ -190,6 +236,9 @@ pub use polars_sql::PolarsReader; #[cfg(feature = "sqlite")] pub use sqlite::SqliteReader; +#[cfg(feature = "odbc")] +pub use odbc::OdbcReader; + // ============================================================================ // Spec - Result of reader.execute() // ============================================================================ @@ -334,37 +383,7 @@ pub trait Reader { /// let writer = VegaLiteWriter::new(); /// let json = writer.render(&spec)?; /// ``` - fn execute(&self, query: &str) -> Result - where - Self: Sized, - { - // Run validation first to capture warnings - let validated = validate(query)?; - let warnings: Vec = validated.warnings().to_vec(); - - // Prepare data with type names for this reader - let prepared_data = prepare_data_with_reader(query, self)?; - - // Get the first (and typically only) spec - let plot = prepared_data.specs.into_iter().next().ok_or_else(|| { - GgsqlError::ValidationError("No visualization spec found".to_string()) - })?; - - // For now, layer_sql and stat_sql are not tracked in PreparedData - // (they were part of main's version but not HEAD's) - let layer_sql = vec![None; plot.layers.len()]; - let stat_sql = vec![None; plot.layers.len()]; - - Ok(Spec::new( - plot, - prepared_data.data, - prepared_data.sql, - prepared_data.visual, - layer_sql, - stat_sql, - warnings, - )) - } + fn execute(&self, query: &str) -> Result; /// Get the SQL dialect for this reader. /// @@ -374,6 +393,35 @@ pub trait Reader { } } +/// Execute a ggsql query using any reader (object-safe entry point). +/// +/// This is the shared implementation behind `Reader::execute()`. 
Concrete +/// readers delegate to this so the trait stays object-safe (no `Self: Sized` +/// bound on `execute`). +pub fn execute_with_reader(reader: &dyn Reader, query: &str) -> Result { + let validated = validate(query)?; + let warnings: Vec = validated.warnings().to_vec(); + + let prepared_data = prepare_data_with_reader(query, reader)?; + + let plot = prepared_data.specs.into_iter().next().ok_or_else(|| { + GgsqlError::ValidationError("No visualization spec found".to_string()) + })?; + + let layer_sql = vec![None; plot.layers.len()]; + let stat_sql = vec![None; plot.layers.len()]; + + Ok(Spec::new( + plot, + prepared_data.data, + prepared_data.sql, + prepared_data.visual, + layer_sql, + stat_sql, + warnings, + )) +} + #[cfg(test)] #[cfg(all(feature = "duckdb", feature = "vegalite"))] mod tests { diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs new file mode 100644 index 00000000..039534a1 --- /dev/null +++ b/src/reader/odbc.rs @@ -0,0 +1,319 @@ +//! Generic ODBC data source implementation +//! +//! Provides a reader for any ODBC-compatible database (Snowflake, PostgreSQL, +//! SQL Server, etc.) using the `odbc-api` crate. + +use crate::reader::Reader; +use crate::{DataFrame, GgsqlError, Result}; +use odbc_api::{buffers::TextRowSet, ConnectionOptions, Cursor, Environment}; +use polars::prelude::*; +use std::sync::OnceLock; + +/// Global ODBC environment (must be a singleton per process). +fn odbc_env() -> &'static Environment { + static ENV: OnceLock = OnceLock::new(); + ENV.get_or_init(|| { + Environment::new().expect("Failed to create ODBC environment") + }) +} + +/// ODBC SQL dialect. +/// +/// Uses ANSI SQL by default. The `variant` field can be used to detect +/// specific backends for dialect customization. +pub struct OdbcDialect { + #[allow(dead_code)] + variant: OdbcVariant, +} + +/// Detected ODBC backend variant. 
+#[derive(Debug, Clone, Copy, PartialEq)] +pub enum OdbcVariant { + Generic, + Snowflake, + PostgreSQL, + SqlServer, +} + +impl super::SqlDialect for OdbcDialect {} + +/// Generic ODBC reader implementing the `Reader` trait. +pub struct OdbcReader { + connection: odbc_api::Connection<'static>, + dialect: OdbcDialect, +} + +// Safety: odbc_api::Connection is Send when we ensure single-threaded access. +// The Reader trait requires &self (immutable) for execute_sql, and ODBC +// connections are safe to use from one thread at a time. +unsafe impl Send for OdbcReader {} + +impl OdbcReader { + /// Create a new ODBC reader from a `odbc://` connection URI. + /// + /// The URI format is `odbc://` followed by the raw ODBC connection string. + /// For Snowflake with Posit Workbench credentials, the reader will + /// automatically detect and inject OAuth tokens. + pub fn from_connection_string(uri: &str) -> Result { + let conn_str = uri + .strip_prefix("odbc://") + .ok_or_else(|| GgsqlError::ReaderError("ODBC URI must start with odbc://".into()))?; + + let mut conn_str = conn_str.to_string(); + + // Snowflake Workbench credential detection + if is_snowflake(&conn_str) && !has_token(&conn_str) { + if let Some(token) = detect_workbench_token() { + conn_str = inject_snowflake_token(&conn_str, &token); + } + } + + // Detect variant from connection string + let variant = detect_variant(&conn_str); + + let env = odbc_env(); + let connection = env + .connect_with_connection_string(&conn_str, ConnectionOptions::default()) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC connection failed: {}", e)))?; + + Ok(Self { + connection, + dialect: OdbcDialect { variant }, + }) + } +} + +impl Reader for OdbcReader { + fn execute_sql(&self, sql: &str) -> Result { + // Execute the query (3rd arg = query timeout, None = no timeout) + let cursor = self + .connection + .execute(sql, (), None) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC execute failed: {}", e)))?; + + let Some(cursor) = 
cursor else { + // DDL or non-query statement — return empty DataFrame + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(format!("Empty DataFrame error: {}", e))); + }; + + cursor_to_dataframe(cursor) + } + + fn register(&self, name: &str, _df: DataFrame, _replace: bool) -> Result<()> { + Err(GgsqlError::ReaderError(format!( + "ODBC reader does not support registering in-memory tables (attempted: '{}')", + name + ))) + } + + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + + fn dialect(&self) -> &dyn super::SqlDialect { + &self.dialect + } +} + +/// Convert an ODBC cursor to a Polars DataFrame by fetching all rows as text. +fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result { + let col_count = cursor.num_result_cols() + .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column count: {}", e)))? + as usize; + + if col_count == 0 { + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(e.to_string())); + } + + // Collect column names + let mut col_names = Vec::with_capacity(col_count); + for i in 1..=col_count as u16 { + let name = cursor + .col_name(i) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e)))?; + col_names.push(name); + } + + // Fetch all rows as text into column-oriented vectors + let batch_size = 1000; + let max_str_len = 4096; + let mut columns: Vec>> = vec![Vec::new(); col_count]; + + let mut row_set = TextRowSet::for_cursor(batch_size, &mut cursor, Some(max_str_len)) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to create row set: {}", e)))?; + + let mut block_cursor = cursor.bind_buffer(&mut row_set) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to bind buffer: {}", e)))?; + + while let Some(batch) = block_cursor.fetch() + .map_err(|e| GgsqlError::ReaderError(format!("Failed to fetch batch: {}", e)))? 
+ { + let num_rows = batch.num_rows(); + for col_idx in 0..col_count { + for row_idx in 0..num_rows { + let value = batch + .at_as_str(col_idx, row_idx) + .ok() + .flatten() + .map(|s| s.to_string()); + columns[col_idx].push(value); + } + } + } + + // Build Polars Series from the text data, attempting type inference + let series: Vec = col_names + .iter() + .zip(columns.iter()) + .map(|(name, values)| { + // Try to parse as numeric first, then fall back to string + let series = if let Some(int_series) = try_parse_integers(name, values) { + int_series + } else if let Some(float_series) = try_parse_floats(name, values) { + float_series + } else { + // Fall back to string + Series::new( + name.into(), + values + .iter() + .map(|v| v.as_deref()) + .collect::>>(), + ) + }; + Column::from(series) + }) + .collect(); + + DataFrame::new(series).map_err(|e| GgsqlError::ReaderError(e.to_string())) +} + +/// Try to parse all non-null values as i64. +fn try_parse_integers(name: &str, values: &[Option]) -> Option { + let parsed: Vec> = values + .iter() + .map(|v| match v { + None => Some(None), + Some(s) => s.parse::().ok().map(Some), + }) + .collect::>>()?; + Some(Series::new(name.into(), parsed)) +} + +/// Try to parse all non-null values as f64. 
+
fn try_parse_floats(name: &str, values: &[Option<String>]) -> Option<Series> {
+    let parsed: Vec<Option<f64>> = values
+        .iter()
+        .map(|v| match v {
+            None => Some(None),
+            Some(s) => s.parse::<f64>().ok().map(Some),
+        })
+        .collect::<Option<Vec<Option<f64>>>>()?;
+    Some(Series::new(name.into(), parsed))
+}
+
+// ============================================================================
+// Snowflake Workbench credential detection
+// ============================================================================
+
+fn is_snowflake(conn_str: &str) -> bool {
+    conn_str.to_lowercase().contains("driver=snowflake")
+}
+
+fn has_token(conn_str: &str) -> bool {
+    conn_str.to_lowercase().contains("token=")
+}
+
+fn detect_variant(conn_str: &str) -> OdbcVariant {
+    let lower = conn_str.to_lowercase();
+    if lower.contains("driver=snowflake") {
+        OdbcVariant::Snowflake
+    } else if lower.contains("driver={postgresql}") || lower.contains("driver=postgresql") {
+        OdbcVariant::PostgreSQL
+    } else if lower.contains("driver={odbc driver") || lower.contains("driver={sql server") {
+        OdbcVariant::SqlServer
+    } else {
+        OdbcVariant::Generic
+    }
+}
+
+/// Detect Posit Workbench Snowflake OAuth token.
+///
+/// Checks `SNOWFLAKE_HOME` for a Workbench-managed `connections.toml` file
+/// containing OAuth credentials.
+fn detect_workbench_token() -> Option<String> {
+    let snowflake_home = std::env::var("SNOWFLAKE_HOME").ok()?;
+
+    // Only use Workbench credentials if the path indicates Workbench management
+    if !snowflake_home.contains("posit-workbench") {
+        return None;
+    }
+
+    let toml_path = std::path::Path::new(&snowflake_home).join("connections.toml");
+    let content = std::fs::read_to_string(&toml_path).ok()?;
+
+    let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
+    let token = doc
+        .get("workbench")?
+        .get("token")?
+        .as_str()?
+        .to_string();
+
+    if token.is_empty() {
+        None
+    } else {
+        Some(token)
+    }
+}
+
+/// Inject OAuth token into a Snowflake ODBC connection string.
+fn inject_snowflake_token(conn_str: &str, token: &str) -> String {
+    // Append authenticator and token parameters
+    let mut result = conn_str.trim_end_matches(';').to_string();
+    result.push_str(";Authenticator=oauth;Token=");
+    result.push_str(token);
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_is_snowflake() {
+        assert!(is_snowflake("Driver=Snowflake;Server=foo.snowflakecomputing.com"));
+        assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost"));
+    }
+
+    #[test]
+    fn test_has_token() {
+        assert!(has_token("Driver=Snowflake;Token=abc123"));
+        assert!(!has_token("Driver=Snowflake;Server=foo"));
+    }
+
+    #[test]
+    fn test_detect_variant() {
+        assert_eq!(
+            detect_variant("Driver=Snowflake;Server=foo"),
+            OdbcVariant::Snowflake
+        );
+        assert_eq!(
+            detect_variant("Driver={PostgreSQL};Server=localhost"),
+            OdbcVariant::PostgreSQL
+        );
+        assert_eq!(
+            detect_variant("Driver=SomeOther;Server=localhost"),
+            OdbcVariant::Generic
+        );
+    }
+
+    #[test]
+    fn test_inject_snowflake_token() {
+        let result =
+            inject_snowflake_token("Driver=Snowflake;Server=foo.snowflakecomputing.com", "mytoken");
+        assert!(result.contains("Authenticator=oauth"));
+        assert!(result.contains("Token=mytoken"));
+    }
+}
diff --git a/src/reader/polars_sql.rs b/src/reader/polars_sql.rs
index 4bf69868..33b039ad 100644
--- a/src/reader/polars_sql.rs
+++ b/src/reader/polars_sql.rs
@@ -204,6 +204,10 @@ impl Reader for PolarsReader {
         Ok(())
     }
 
+    fn execute(&self, query: &str) -> Result<Spec> {
+        super::execute_with_reader(self, query)
+    }
+
     fn unregister(&self, name: &str) -> Result<()> {
         // Only allow unregistering tables we created via register()
         if !self.registered_tables.borrow().contains(name) {
diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs
index 09d8b015..81dd1c17 100644
--- a/src/reader/sqlite.rs
+++ b/src/reader/sqlite.rs
@@ -45,6 +45,29 @@ impl super::SqlDialect for SqliteDialect {
     fn time_type_name(&self) -> Option<&str> {
         Some("TEXT")
     }
+
+    fn sql_list_catalogs(&self) -> String {
+        "SELECT name AS catalog_name FROM pragma_database_list ORDER BY name".into()
+    }
+
+    fn sql_list_schemas(&self, _catalog: &str) -> String {
+        "SELECT 'main' AS schema_name".into()
+    }
+
+    fn sql_list_tables(&self, catalog: &str, _schema: &str) -> String {
+        format!(
+            "SELECT name AS table_name, type AS table_type FROM \"{}\".sqlite_master \
+            WHERE type IN ('table', 'view') ORDER BY name",
+            catalog.replace('"', "\"\"")
+        )
+    }
+
+    fn sql_list_columns(&self, _catalog: &str, _schema: &str, table: &str) -> String {
+        format!(
+            "SELECT name AS column_name, type AS data_type FROM pragma_table_info('{}') ORDER BY cid",
+            table.replace('\'', "''")
+        )
+    }
 }
 
 /// SQLite database reader
@@ -420,6 +443,10 @@ impl Reader for SqliteReader {
         Ok(())
     }
 
+    fn execute(&self, query: &str) -> Result<Spec> {
+        super::execute_with_reader(self, query)
+    }
+
     fn dialect(&self) -> &dyn super::SqlDialect {
         &SqliteDialect
     }

From 8dc3310430a35ffafbe202f1f33645c61c18b9e4 Mon Sep 17 00:00:00 2001
From: George Stagg 
Date: Wed, 18 Mar 2026 10:34:40 +0000
Subject: [PATCH 02/20] WIP: ODBC

---
 ggsql-vscode/package-lock.json  |   9 +
 ggsql-vscode/package.json       |   3 +
 ggsql-vscode/src/connections.ts | 334 ++++++++++++++++++++++----
 src/Cargo.toml                  |   1 +
 src/reader/duckdb.rs            |  29 +--
 src/reader/mod.rs               |  36 +++
 src/reader/odbc.rs              | 412 ++++++++++++++++++++++++++++++--
 7 files changed, 730 insertions(+), 94 deletions(-)

diff --git a/ggsql-vscode/package-lock.json b/ggsql-vscode/package-lock.json
index ac31ff6c..2153588b 100644
--- a/ggsql-vscode/package-lock.json
+++ b/ggsql-vscode/package-lock.json
@@ -8,6 +8,9 @@
       "name": "ggsql",
       "version": "0.1.8",
       "license": "MIT",
+      "dependencies": {
+        "toml": "^3.0.0"
+      },
       "devDependencies": {
         "@posit-dev/positron": "^0.2.2",
         "@types/node": "^18.x",
@@ -3963,6 +3966,12 @@
         "node": ">=8.0"
       }
     },
+    "node_modules/toml": {
+      "version": "3.0.0",
+      "resolved": "https://registry.npmjs.org/toml/-/toml-3.0.0.tgz",
+      "integrity": 
"sha512-y/mWCZinnvxjTKYhJ+pYxwD0mRLVvOtdS2Awbgxln6iEnt4rk0yBxeSBHkGJcPucRiG0e55mwWp+g/05rsrd6w==", + "license": "MIT" + }, "node_modules/ts-api-utils": { "version": "1.4.3", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.4.3.tgz", diff --git a/ggsql-vscode/package.json b/ggsql-vscode/package.json index 57a0ded1..db0d3a22 100644 --- a/ggsql-vscode/package.json +++ b/ggsql-vscode/package.json @@ -76,6 +76,9 @@ "check-types": "tsc --noEmit", "lint": "eslint src --ext ts" }, + "dependencies": { + "toml": "^3.0.0" + }, "devDependencies": { "@posit-dev/positron": "^0.2.2", "@types/node": "^18.x", diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts index 9c2b7e11..0df0641f 100644 --- a/ggsql-vscode/src/connections.ts +++ b/ggsql-vscode/src/connections.ts @@ -6,9 +6,14 @@ * that the ggsql-jupyter kernel interprets to switch readers. */ +import * as os from 'os'; +import * as path from 'path'; +import * as fs from 'fs'; +import * as toml from 'toml'; import type * as positron from '@posit-dev/positron'; type PositronApi = positron.PositronApi; +type ConnectionsDriverMetadata = positron.ConnectionsDriverMetadata & { description?: string }; /** * Create the set of ggsql connection drivers to register with Positron. @@ -18,11 +23,18 @@ export function createConnectionDrivers( ): positron.ConnectionsDriver[] { return [ createDuckDBDriver(positronApi), - createSnowflakeDriver(positronApi), + createSnowflakeDefaultDriver(positronApi), + createSnowflakePasswordDriver(positronApi), + createSnowflakeSSODriver(positronApi), + createSnowflakePATDriver(positronApi), createOdbcDriver(positronApi), ]; } +// ============================================================================ +// DuckDB +// ============================================================================ + /** * DuckDB connection driver. 
* @@ -58,66 +70,300 @@ function createDuckDBDriver( }; } +// ============================================================================ +// Snowflake — shared helpers +// ============================================================================ + +interface SnowflakeConnectionEntry { + name: string; + account?: string; +} + /** - * Snowflake connection driver (via ODBC). - * - * Builds an ODBC connection string targeting the Snowflake ODBC driver. - * Workbench OAuth token injection happens automatically in the kernel. + * Find the Snowflake connections.toml file, checking standard locations. */ -function createSnowflakeDriver( +function findSnowflakeConnectionsToml(): string | undefined { + const candidates: string[] = []; + + // 1. $SNOWFLAKE_HOME/connections.toml + const snowflakeHome = process.env.SNOWFLAKE_HOME; + if (snowflakeHome) { + candidates.push(path.join(snowflakeHome, 'connections.toml')); + } + + // 2. ~/.snowflake/connections.toml + const home = os.homedir(); + candidates.push(path.join(home, '.snowflake', 'connections.toml')); + + // 3. Platform-specific paths + if (process.platform === 'darwin') { + candidates.push( + path.join(home, 'Library', 'Application Support', 'snowflake', 'connections.toml') + ); + } else if (process.platform === 'linux') { + const xdgConfig = process.env.XDG_CONFIG_HOME || path.join(home, '.config'); + candidates.push(path.join(xdgConfig, 'snowflake', 'connections.toml')); + } else if (process.platform === 'win32') { + candidates.push( + path.join(home, 'AppData', 'Local', 'snowflake', 'connections.toml') + ); + } + + for (const candidate of candidates) { + if (fs.existsSync(candidate)) { + return candidate; + } + } + return undefined; +} + +/** + * Read Snowflake connection entries from connections.toml. 
+ */
+function readSnowflakeConnections(): {
+  connections: SnowflakeConnectionEntry[];
+  defaultConnection?: string;
+} {
+  const tomlPath = findSnowflakeConnectionsToml();
+  if (!tomlPath) {
+    return { connections: [] };
+  }
+
+  try {
+    const content = fs.readFileSync(tomlPath, 'utf-8');
+    const parsed = toml.parse(content);
+
+    const defaultConnection =
+      process.env.SNOWFLAKE_DEFAULT_CONNECTION_NAME ||
+      parsed.default_connection_name ||
+      undefined;
+
+    const connections: SnowflakeConnectionEntry[] = Object.keys(parsed)
+      .filter(
+        (key) =>
+          key !== 'default_connection_name' &&
+          typeof parsed[key] === 'object' &&
+          parsed[key] !== null
+      )
+      .map((name) => ({
+        name,
+        account: parsed[name].account as string | undefined,
+      }));
+
+    return { connections, defaultConnection };
+  } catch {
+    return { connections: [] };
+  }
+}
+
+/**
+ * Build an ODBC connection string for Snowflake with the given parts.
+ */
+function buildSnowflakeOdbc(parts: Record<string, string | undefined>): string {
+  let connStr = `Driver=Snowflake;Server=${parts.account}.snowflakecomputing.com`;
+  if (parts.uid) {
+    connStr += `;UID=${parts.uid}`;
+  }
+  if (parts.pwd) {
+    connStr += `;PWD=${parts.pwd}`;
+  }
+  if (parts.authenticator) {
+    connStr += `;Authenticator=${parts.authenticator}`;
+  }
+  if (parts.token) {
+    connStr += `;Token=${parts.token}`;
+  }
+  if (parts.warehouse) {
+    connStr += `;Warehouse=${parts.warehouse}`;
+  }
+  if (parts.database) {
+    connStr += `;Database=${parts.database}`;
+  }
+  if (parts.schema) {
+    connStr += `;Schema=${parts.schema}`;
+  }
+  return `-- @connect: odbc://${connStr}`;
+}
+
+function snowflakeConnect(positronApi: PositronApi) {
+  return async (code: string) => {
+    await positronApi.runtime.executeCode('ggsql', code, false);
+  };
+}
+
+// ============================================================================
+// Snowflake — Default Connection (connections.toml)
+// ============================================================================
+
+function createSnowflakeDefaultDriver(
+  positronApi: PositronApi
+): positron.ConnectionsDriver {
+  const { connections, defaultConnection } = readSnowflakeConnections();
+
+  let inputs: positron.ConnectionsInput[];
+  if (connections.length > 0) {
+    const defaultValue =
+      defaultConnection ||
+      (connections.find((c) => c.name === 'default')?.name ?? connections[0].name);
+
+    inputs = [
+      {
+        id: 'connection_name',
+        label: 'Connection Name',
+        type: 'option',
+        options: connections.map((conn) => ({
+          identifier: conn.name,
+          title: conn.account
+            ? `${conn.name} (${conn.account})`
+            : conn.name,
+        })),
+        value: defaultValue,
+      },
+    ];
+  } else {
+    inputs = [
+      {
+        id: 'connection_name',
+        label: 'Connection Name',
+        type: 'string',
+        value: 'default',
+      },
+    ];
+  }
+
+  return {
+    driverId: 'ggsql-snowflake-default',
+    metadata: {
+      languageId: 'ggsql',
+      name: 'Snowflake',
+      description: 'Default Connection (connections.toml)',
+      inputs,
+    } as ConnectionsDriverMetadata,
+    generateCode: (inputs) => {
+      const name =
+        inputs.find((i) => i.id === 'connection_name')?.value?.trim() || 'default';
+      return `-- @connect: odbc://Driver=Snowflake;ConnectionName=${name}`;
+    },
+    connect: snowflakeConnect(positronApi),
+  };
+}
+
+// ============================================================================
+// Snowflake — Username/Password
+// ============================================================================
+
+function createSnowflakePasswordDriver(
   positronApi: PositronApi
 ): positron.ConnectionsDriver {
   return {
-    driverId: 'ggsql-snowflake',
+    driverId: 'ggsql-snowflake-password',
     metadata: {
       languageId: 'ggsql',
       name: 'Snowflake',
+      description: 'Username/Password',
       inputs: [
-        {
-          id: 'account',
-          label: 'Account',
-          type: 'string',
-        },
-        {
-          id: 'warehouse',
-          label: 'Warehouse',
-          type: 'string',
-        },
-        {
-          id: 'database',
-          label: 'Database',
-          type: 'string',
-          value: '',
-        },
-        {
-          id: 'schema',
-          label: 'Schema',
-          type: 'string',
-          value: '',
-        },
+        { id: 
'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string' }, + { id: 'password', label: 'Password', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user'), + pwd: get('password'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — External Browser (SSO) +// ============================================================================ + +function createSnowflakeSSODriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-sso', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'External Browser (SSO)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string', value: '' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, generateCode: (inputs) => { - const account = inputs.find((i) => i.id === 'account')?.value ?? ''; - const warehouse = inputs.find((i) => i.id === 'warehouse')?.value ?? ''; - const database = inputs.find((i) => i.id === 'database')?.value ?? ''; - const schema = inputs.find((i) => i.id === 'schema')?.value ?? 
''; - - let connStr = `Driver=Snowflake;Server=${account}.snowflakecomputing.com;Warehouse=${warehouse}`; - if (database) { - connStr += `;Database=${database}`; - } - if (schema) { - connStr += `;Schema=${schema}`; - } - return `-- @connect: odbc://${connStr}`; + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user') || undefined, + authenticator: 'externalbrowser', + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); }, - connect: async (code: string) => { - await positronApi.runtime.executeCode('ggsql', code, false); + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — Programmatic Access Token (PAT) +// ============================================================================ + +function createSnowflakePATDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-pat', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Programmatic Access Token (PAT)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'token', label: 'Token', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + authenticator: 'programmatic_access_token', + token: get('token'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); }, + connect: snowflakeConnect(positronApi), }; } +// 
============================================================================ +// Generic ODBC +// ============================================================================ + /** * Generic ODBC connection driver. * diff --git a/src/Cargo.toml b/src/Cargo.toml index 64ad22f1..284917ff 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -75,6 +75,7 @@ pyo3 = { workspace = true, optional = true } [dev-dependencies] jsonschema = "0.44" proptest.workspace = true +tempfile = "3.8" ureq = "3" [features] diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 62c5c3e2..3a0cf616 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -144,34 +144,7 @@ impl DuckDBReader { } } -/// Validate a table name -fn validate_table_name(name: &str) -> Result<()> { - if name.is_empty() { - return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); - } - - // Reject characters that could break double-quoted identifiers or cause issues - let forbidden = ['"', '\0', '\n', '\r']; - for ch in forbidden { - if name.contains(ch) { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' contains invalid character '{}'", - name, - ch.escape_default() - ))); - } - } - - // Reasonable length limit - if name.len() > 128 { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' exceeds maximum length of 128 characters", - name - ))); - } - - Ok(()) -} +use super::validate_table_name; /// Convert a Polars DataFrame to DuckDB Arrow query parameters via IPC serialization fn dataframe_to_arrow_params(df: DataFrame) -> Result<[usize; 2]> { diff --git a/src/reader/mod.rs b/src/reader/mod.rs index e64c7bc4..02a79c63 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -239,6 +239,42 @@ pub use sqlite::SqliteReader; #[cfg(feature = "odbc")] pub use odbc::OdbcReader; +// ============================================================================ +// Shared utilities +// ============================================================================ + +/// 
Validate a table name for use in SQL statements.
+///
+/// Rejects empty names, names with characters that could break double-quoted
+/// identifiers, and names exceeding 128 characters.
+pub(crate) fn validate_table_name(name: &str) -> Result<()> {
+    if name.is_empty() {
+        return Err(GgsqlError::ReaderError("Table name cannot be empty".into()));
+    }
+
+    // Reject characters that could break double-quoted identifiers or cause issues
+    let forbidden = ['"', '\0', '\n', '\r'];
+    for ch in forbidden {
+        if name.contains(ch) {
+            return Err(GgsqlError::ReaderError(format!(
+                "Table name '{}' contains invalid character '{}'",
+                name,
+                ch.escape_default()
+            )));
+        }
+    }
+
+    // Reasonable length limit
+    if name.len() > 128 {
+        return Err(GgsqlError::ReaderError(format!(
+            "Table name '{}' exceeds maximum length of 128 characters",
+            name
+        )));
+    }
+
+    Ok(())
+}
+
 // ============================================================================
 // Spec - Result of reader.execute()
 // ============================================================================
diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs
index 039534a1..bab8a0b2 100644
--- a/src/reader/odbc.rs
+++ b/src/reader/odbc.rs
@@ -7,14 +7,14 @@ use crate::reader::Reader;
 use crate::{DataFrame, GgsqlError, Result};
 use odbc_api::{buffers::TextRowSet, ConnectionOptions, Cursor, Environment};
 use polars::prelude::*;
+use std::cell::RefCell;
+use std::collections::HashSet;
 use std::sync::OnceLock;
 
 /// Global ODBC environment (must be a singleton per process).
 fn odbc_env() -> &'static Environment {
     static ENV: OnceLock<Environment> = OnceLock::new();
-    ENV.get_or_init(|| {
-        Environment::new().expect("Failed to create ODBC environment")
-    })
+    ENV.get_or_init(|| Environment::new().expect("Failed to create ODBC environment"))
 }
 
 /// ODBC SQL dialect.
@@ -41,6 +41,7 @@ impl super::SqlDialect for OdbcDialect {}
 pub struct OdbcReader {
     connection: odbc_api::Connection<'static>,
     dialect: OdbcDialect,
+    registered_tables: RefCell<HashSet<String>>,
 }
 
 // Safety: odbc_api::Connection is Send when we ensure single-threaded access.
@@ -61,6 +62,13 @@ impl OdbcReader {
 
         let mut conn_str = conn_str.to_string();
 
+        // Snowflake ConnectionName resolution from connections.toml
+        if is_snowflake(&conn_str) {
+            if let Some(resolved) = resolve_connection_name(&conn_str) {
+                conn_str = resolved;
+            }
+        }
+
         // Snowflake Workbench credential detection
         if is_snowflake(&conn_str) && !has_token(&conn_str) {
             if let Some(token) = detect_workbench_token() {
@@ -79,6 +87,7 @@ impl OdbcReader {
         Ok(Self {
             connection,
             dialect: OdbcDialect { variant },
+            registered_tables: RefCell::new(HashSet::new()),
         })
     }
 }
@@ -100,11 +109,145 @@ impl Reader for OdbcReader {
         cursor_to_dataframe(cursor)
     }
 
-    fn register(&self, name: &str, _df: DataFrame, _replace: bool) -> Result<()> {
-        Err(GgsqlError::ReaderError(format!(
-            "ODBC reader does not support registering in-memory tables (attempted: '{}')",
-            name
-        )))
+    fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> {
+        super::validate_table_name(name)?;
+
+        if replace {
+            let drop_sql = format!("DROP TABLE IF EXISTS \"{}\"", name);
+            // Ignore errors from DROP — table may not exist
+            let _ = self.connection.execute(&drop_sql, (), None);
+        }
+
+        // Build CREATE TEMP TABLE with typed columns
+        let schema = df.schema();
+        let col_defs: Vec<String> = schema
+            .iter()
+            .map(|(col_name, dtype)| format!("\"{}\" {}", col_name, polars_dtype_to_sql(dtype)))
+            .collect();
+        let create_sql = format!(
+            "CREATE TEMPORARY TABLE \"{}\" ({})",
+            name,
+            col_defs.join(", ")
+        );
+        self.connection
+            .execute(&create_sql, (), None)
+            .map_err(|e| {
+                GgsqlError::ReaderError(format!("Failed to create temp table '{}': {}", name, e))
+            })?;
+
+        // Insert data using ODBC bulk text inserter
+        let num_rows = df.height();
+        if num_rows > 0 {
+            let num_cols = df.width();
+            let placeholders: Vec<&str> = vec!["?"; num_cols];
+            let insert_sql = format!(
+                "INSERT INTO \"{}\" VALUES ({})",
+                name,
+                placeholders.join(", ")
+            );
+
+            // Convert all columns to string representation for text insertion
+            let string_columns: Vec<Vec<Option<String>>> = df
+                .get_columns()
+                .iter()
+                .map(|col| {
+                    (0..num_rows)
+                        .map(|row| {
+                            let val = col.get(row).ok()?;
+                            if val == AnyValue::Null {
+                                None
+                            } else {
+                                Some(format!("{}", val))
+                            }
+                        })
+                        .collect()
+                })
+                .collect();
+
+            // Determine max string length per column for buffer allocation
+            let max_str_lens: Vec<usize> = string_columns
+                .iter()
+                .map(|col| {
+                    col.iter()
+                        .filter_map(|v| v.as_ref().map(|s| s.len()))
+                        .max()
+                        .unwrap_or(1)
+                        .max(1) // minimum buffer size of 1
+                })
+                .collect();
+
+            const BATCH_SIZE: usize = 1024;
+            let prepared = self.connection.prepare(&insert_sql).map_err(|e| {
+                GgsqlError::ReaderError(format!("Failed to prepare INSERT for '{}': {}", name, e))
+            })?;
+
+            let batch_capacity = num_rows.min(BATCH_SIZE);
+            let mut inserter = prepared
+                .into_text_inserter(batch_capacity, max_str_lens)
+                .map_err(|e| {
+                    GgsqlError::ReaderError(format!(
+                        "Failed to create bulk inserter for '{}': {}",
+                        name, e
+                    ))
+                })?;
+
+            let mut rows_in_batch = 0;
+            for row_idx in 0..num_rows {
+                let row_values: Vec<Option<&[u8]>> = string_columns
+                    .iter()
+                    .map(|col| col[row_idx].as_ref().map(|s| s.as_bytes()))
+                    .collect();
+
+                inserter.append(row_values.into_iter()).map_err(|e| {
+                    GgsqlError::ReaderError(format!(
+                        "Failed to append row {} to '{}': {}",
+                        row_idx, name, e
+                    ))
+                })?;
+                rows_in_batch += 1;
+
+                if rows_in_batch >= BATCH_SIZE {
+                    inserter.execute().map_err(|e| {
+                        GgsqlError::ReaderError(format!(
+                            "Failed to execute batch insert into '{}': {}",
+                            name, e
+                        ))
+                    })?;
+                    inserter.clear();
+                    rows_in_batch = 0;
+                }
+            }
+
+            // Execute final partial batch
+            if rows_in_batch > 0 {
+                inserter.execute().map_err(|e| {
+                    GgsqlError::ReaderError(format!(
+                        "Failed to execute final batch insert into '{}': {}",
+                        name, e
+                    ))
+                })?;
+            }
+        }
+
+        self.registered_tables.borrow_mut().insert(name.to_string());
+        Ok(())
+    }
+
+    fn unregister(&self, name: &str) -> Result<()> {
+        if !self.registered_tables.borrow().contains(name) {
+            return Err(GgsqlError::ReaderError(format!(
+                "Table '{}' was not registered via this reader",
+                name
+            )));
+        }
+
+        let sql = format!("DROP TABLE IF EXISTS \"{}\"", name);
+        self.connection.execute(&sql, (), None).map_err(|e| {
+            GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e))
+        })?;
+
+        self.registered_tables.borrow_mut().remove(name);
+        Ok(())
     }
 
     fn execute(&self, query: &str) -> Result<Spec> {
@@ -116,9 +259,24 @@ impl Reader for OdbcReader {
     }
 }
 
+/// Map a Polars data type to a SQL column type string.
+fn polars_dtype_to_sql(dtype: &DataType) -> &'static str {
+    match dtype {
+        DataType::Boolean => "BOOLEAN",
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT",
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT",
+        DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION",
+        DataType::Date => "DATE",
+        DataType::Datetime(_, _) => "TIMESTAMP",
+        DataType::Time => "TIME",
+        _ => "TEXT",
+    }
+}
+
 /// Convert an ODBC cursor to a Polars DataFrame by fetching all rows as text.
 fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
-    let col_count = cursor.num_result_cols()
+    let col_count = cursor
+        .num_result_cols()
         .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column count: {}", e)))? as usize;
@@ -130,9 +288,9 @@ fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
     // Collect column names
     let mut col_names = Vec::with_capacity(col_count);
     for i in 1..=col_count as u16 {
-        let name = cursor
-            .col_name(i)
-            .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e)))?;
+        let name = cursor.col_name(i).map_err(|e| {
+            GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e))
+        })?;
         col_names.push(name);
     }
 
@@ -144,10 +302,12 @@ fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
     let mut row_set = TextRowSet::for_cursor(batch_size, &mut cursor, Some(max_str_len))
         .map_err(|e| GgsqlError::ReaderError(format!("Failed to create row set: {}", e)))?;
 
-    let mut block_cursor = cursor.bind_buffer(&mut row_set)
+    let mut block_cursor = cursor
+        .bind_buffer(&mut row_set)
         .map_err(|e| GgsqlError::ReaderError(format!("Failed to bind buffer: {}", e)))?;
 
-    while let Some(batch) = block_cursor.fetch()
+    while let Some(batch) = block_cursor
+        .fetch()
         .map_err(|e| GgsqlError::ReaderError(format!("Failed to fetch batch: {}", e)))?
     {
         let num_rows = batch.num_rows();
@@ -239,6 +399,139 @@ fn detect_variant(conn_str: &str) -> OdbcVariant {
     }
 }
 
+fn home_dir() -> Option<std::path::PathBuf> {
+    #[cfg(target_os = "windows")]
+    {
+        std::env::var("USERPROFILE").ok().map(std::path::PathBuf::from)
+    }
+    #[cfg(not(target_os = "windows"))]
+    {
+        std::env::var("HOME").ok().map(std::path::PathBuf::from)
+    }
+}
+
+/// Find the Snowflake connections.toml file, checking standard locations.
+fn find_snowflake_connections_toml() -> Option<std::path::PathBuf> {
+    use std::path::PathBuf;
+
+    // 1. $SNOWFLAKE_HOME/connections.toml
+    if let Ok(snowflake_home) = std::env::var("SNOWFLAKE_HOME") {
+        let p = PathBuf::from(&snowflake_home).join("connections.toml");
+        if p.exists() {
+            return Some(p);
+        }
+    }
+
+    // 2. 
~/.snowflake/connections.toml + if let Some(home) = home_dir() { + let p = home.join(".snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + // 3. Platform-specific paths + if let Some(home) = home_dir() { + #[cfg(target_os = "macos")] + { + let p = home + .join("Library/Application Support/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "linux")] + { + let xdg = std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| home.join(".config")); + let p = xdg.join("snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "windows")] + { + let p = home + .join("AppData/Local/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + } + + None +} + +/// Resolve a `ConnectionName=` parameter in a Snowflake ODBC connection +/// string by reading the named entry from `~/.snowflake/connections.toml` and +/// building a full ODBC connection string from it. 
+fn resolve_connection_name(conn_str: &str) -> Option<String> {
+    // Extract ConnectionName value (case-insensitive)
+    let lower = conn_str.to_lowercase();
+    let cn_key = "connectionname=";
+    let cn_start = lower.find(cn_key)?;
+    let value_start = cn_start + cn_key.len();
+
+    let rest = &conn_str[value_start..];
+    let value_end = rest.find(';').unwrap_or(rest.len());
+    let connection_name = rest[..value_end].trim();
+
+    if connection_name.is_empty() {
+        return None;
+    }
+
+    // Read and parse connections.toml
+    let toml_path = find_snowflake_connections_toml()?;
+    let content = std::fs::read_to_string(&toml_path).ok()?;
+    let doc = content.parse::<toml_edit::DocumentMut>().ok()?;
+
+    let entry = doc.get(connection_name)?;
+    if !entry.is_table() && !entry.is_inline_table() {
+        return None;
+    }
+
+    // Build ODBC connection string from TOML entry fields
+    let get_str = |key: &str| -> Option<String> {
+        entry.get(key)?.as_str().map(|s| s.to_string())
+    };
+
+    let account = get_str("account")?;
+    let mut parts = vec![
+        "Driver=Snowflake".to_string(),
+        format!("Server={}.snowflakecomputing.com", account),
+    ];
+
+    if let Some(user) = get_str("user") {
+        parts.push(format!("UID={}", user));
+    }
+    if let Some(password) = get_str("password") {
+        parts.push(format!("PWD={}", password));
+    }
+    if let Some(authenticator) = get_str("authenticator") {
+        parts.push(format!("Authenticator={}", authenticator));
+    }
+    if let Some(token) = get_str("token") {
+        parts.push(format!("Token={}", token));
+    }
+    if let Some(warehouse) = get_str("warehouse") {
+        parts.push(format!("Warehouse={}", warehouse));
+    }
+    if let Some(database) = get_str("database") {
+        parts.push(format!("Database={}", database));
+    }
+    if let Some(schema) = get_str("schema") {
+        parts.push(format!("Schema={}", schema));
+    }
+    if let Some(role) = get_str("role") {
+        parts.push(format!("Role={}", role));
+    }
+
+    Some(parts.join(";"))
+}
+
 /// Detect Posit Workbench Snowflake OAuth token. 
/// /// Checks `SNOWFLAKE_HOME` for a Workbench-managed `connections.toml` file @@ -255,11 +548,7 @@ fn detect_workbench_token() -> Option { let content = std::fs::read_to_string(&toml_path).ok()?; let doc = content.parse::().ok()?; - let token = doc - .get("workbench")? - .get("token")? - .as_str()? - .to_string(); + let token = doc.get("workbench")?.get("token")?.as_str()?.to_string(); if token.is_empty() { None @@ -283,7 +572,9 @@ mod tests { #[test] fn test_is_snowflake() { - assert!(is_snowflake("Driver=Snowflake;Server=foo.snowflakecomputing.com")); + assert!(is_snowflake( + "Driver=Snowflake;Server=foo.snowflakecomputing.com" + )); assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost")); } @@ -311,9 +602,86 @@ mod tests { #[test] fn test_inject_snowflake_token() { - let result = - inject_snowflake_token("Driver=Snowflake;Server=foo.snowflakecomputing.com", "mytoken"); + let result = inject_snowflake_token( + "Driver=Snowflake;Server=foo.snowflakecomputing.com", + "mytoken", + ); assert!(result.contains("Authenticator=oauth")); assert!(result.contains("Token=mytoken")); } + + #[test] + fn test_resolve_connection_name_with_toml() { + use std::io::Write; + + // Create a temp dir with a connections.toml + let dir = tempfile::tempdir().unwrap(); + let toml_path = dir.path().join("connections.toml"); + let mut f = std::fs::File::create(&toml_path).unwrap(); + writeln!( + f, + r#" +default_connection_name = "myconn" + +[myconn] +account = "myaccount" +user = "myuser" +password = "mypass" +warehouse = "mywh" +database = "mydb" +schema = "public" +role = "myrole" + +[other] +account = "otheraccount" +"# + ) + .unwrap(); + + // Point SNOWFLAKE_HOME at our temp dir + std::env::set_var("SNOWFLAKE_HOME", dir.path()); + + let result = + resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); + assert!(result.is_some()); + let conn = result.unwrap(); + assert!(conn.contains("Driver=Snowflake")); + 
assert!(conn.contains("Server=myaccount.snowflakecomputing.com")); + assert!(conn.contains("UID=myuser")); + assert!(conn.contains("PWD=mypass")); + assert!(conn.contains("Warehouse=mywh")); + assert!(conn.contains("Database=mydb")); + assert!(conn.contains("Schema=public")); + assert!(conn.contains("Role=myrole")); + + // Test with a connection that has fewer fields + let result2 = + resolve_connection_name("Driver=Snowflake;ConnectionName=other"); + assert!(result2.is_some()); + let conn2 = result2.unwrap(); + assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com")); + assert!(!conn2.contains("UID=")); + + // Test with non-existent connection name + let result3 = + resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); + assert!(result3.is_none()); + + // No ConnectionName param → None + let result4 = + resolve_connection_name("Driver=Snowflake;Server=foo"); + assert!(result4.is_none()); + + // Clean up env + std::env::remove_var("SNOWFLAKE_HOME"); + } + + #[test] + fn test_polars_dtype_to_sql() { + assert_eq!(polars_dtype_to_sql(&DataType::Int64), "BIGINT"); + assert_eq!(polars_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION"); + assert_eq!(polars_dtype_to_sql(&DataType::Boolean), "BOOLEAN"); + assert_eq!(polars_dtype_to_sql(&DataType::Date), "DATE"); + assert_eq!(polars_dtype_to_sql(&DataType::String), "TEXT"); + } } From 27144dd35ac65f5303dfde3406bf3c6175fa9f05 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Wed, 18 Mar 2026 14:02:46 +0000 Subject: [PATCH 03/20] WIP: ODBC --- src/reader/odbc.rs | 85 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs index bab8a0b2..b97b967c 100644 --- a/src/reader/odbc.rs +++ b/src/reader/odbc.rs @@ -35,7 +35,90 @@ pub enum OdbcVariant { SqlServer, } -impl super::SqlDialect for OdbcDialect {} +impl super::SqlDialect for OdbcDialect { + fn sql_list_catalogs(&self) -> String { + match self.variant 
{ + OdbcVariant::Snowflake => { + "SELECT database_name AS catalog_name \ + FROM snowflake.information_schema.databases \ + ORDER BY database_name" + .into() + } + _ => { + "SELECT DISTINCT catalog_name FROM information_schema.schemata \ + ORDER BY catalog_name" + .into() + } + } + } + + fn sql_list_schemas(&self, catalog: &str) -> String { + let catalog = catalog.replace('\'', "''"); + match self.variant { + OdbcVariant::Snowflake => { + let catalog_ident = catalog.replace('"', "\"\""); + format!( + "SELECT schema_name \ + FROM \"{catalog_ident}\".information_schema.schemata \ + ORDER BY schema_name" + ) + } + _ => { + format!( + "SELECT DISTINCT schema_name FROM information_schema.schemata \ + WHERE catalog_name = '{catalog}' ORDER BY schema_name" + ) + } + } + } + + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + let schema = schema.replace('\'', "''"); + match self.variant { + OdbcVariant::Snowflake => { + let catalog_ident = catalog.replace('"', "\"\""); + format!( + "SELECT table_name, table_type \ + FROM \"{catalog_ident}\".information_schema.tables \ + WHERE table_schema = '{schema}' \ + ORDER BY table_name" + ) + } + _ => { + let catalog = catalog.replace('\'', "''"); + format!( + "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ + WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ + ORDER BY table_name" + ) + } + } + } + + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + let schema = schema.replace('\'', "''"); + let table = table.replace('\'', "''"); + match self.variant { + OdbcVariant::Snowflake => { + let catalog_ident = catalog.replace('"', "\"\""); + format!( + "SELECT column_name, data_type \ + FROM \"{catalog_ident}\".information_schema.columns \ + WHERE table_schema = '{schema}' AND table_name = '{table}' \ + ORDER BY ordinal_position" + ) + } + _ => { + let catalog = catalog.replace('\'', "''"); + format!( + "SELECT column_name, data_type FROM 
information_schema.columns \ + WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ + AND table_name = '{table}' ORDER BY ordinal_position" + ) + } + } + } +} /// Generic ODBC reader implementing the `Reader` trait. pub struct OdbcReader { From 41a8b9961cec92f74fa3b804aa6ee011cc4746a3 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Wed, 18 Mar 2026 14:37:48 +0000 Subject: [PATCH 04/20] WIP: ODBC --- ggsql-jupyter/src/connection.rs | 12 ++++++++---- src/reader/odbc.rs | 34 +++++++++++---------------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs index a1b3085d..0b28c80c 100644 --- a/ggsql-jupyter/src/connection.rs +++ b/ggsql-jupyter/src/connection.rs @@ -65,7 +65,8 @@ fn list_catalogs(reader: &dyn Reader) -> Result, String> { let col = df .column("catalog_name") - .map_err(|e| format!("Missing catalog_name column: {}", e))?; + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing catalog_name/name column: {}", e))?; let mut catalogs = Vec::new(); for i in 0..df.height() { @@ -88,7 +89,8 @@ fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result, let col = df .column("schema_name") - .map_err(|e| format!("Missing schema_name column: {}", e))?; + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing schema_name/name column: {}", e))?; let mut schemas = Vec::new(); for i in 0..df.height() { @@ -115,10 +117,12 @@ fn list_tables( let name_col = df .column("table_name") - .map_err(|e| format!("Missing table_name column: {}", e))?; + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing table_name/name column: {}", e))?; let type_col = df .column("table_type") - .map_err(|e| format!("Missing table_type column: {}", e))?; + .or_else(|_| df.column("kind")) + .map_err(|e| format!("Missing table_type/kind column: {}", e))?; let mut objects = Vec::new(); for i in 0..df.height() { diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs index 
b97b967c..cb3f06c0 100644 --- a/src/reader/odbc.rs +++ b/src/reader/odbc.rs @@ -38,12 +38,7 @@ pub enum OdbcVariant { impl super::SqlDialect for OdbcDialect { fn sql_list_catalogs(&self) -> String { match self.variant { - OdbcVariant::Snowflake => { - "SELECT database_name AS catalog_name \ - FROM snowflake.information_schema.databases \ - ORDER BY database_name" - .into() - } + OdbcVariant::Snowflake => "SHOW DATABASES".into(), _ => { "SELECT DISTINCT catalog_name FROM information_schema.schemata \ ORDER BY catalog_name" @@ -53,17 +48,13 @@ impl super::SqlDialect for OdbcDialect { } fn sql_list_schemas(&self, catalog: &str) -> String { - let catalog = catalog.replace('\'', "''"); match self.variant { OdbcVariant::Snowflake => { let catalog_ident = catalog.replace('"', "\"\""); - format!( - "SELECT schema_name \ - FROM \"{catalog_ident}\".information_schema.schemata \ - ORDER BY schema_name" - ) + format!("SHOW SCHEMAS IN DATABASE \"{catalog_ident}\"") } _ => { + let catalog = catalog.replace('\'', "''"); format!( "SELECT DISTINCT schema_name FROM information_schema.schemata \ WHERE catalog_name = '{catalog}' ORDER BY schema_name" @@ -73,19 +64,17 @@ impl super::SqlDialect for OdbcDialect { } fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { - let schema = schema.replace('\'', "''"); match self.variant { OdbcVariant::Snowflake => { let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); format!( - "SELECT table_name, table_type \ - FROM \"{catalog_ident}\".information_schema.tables \ - WHERE table_schema = '{schema}' \ - ORDER BY table_name" + "SHOW OBJECTS IN SCHEMA \"{catalog_ident}\".\"{schema_ident}\"" ) } _ => { let catalog = catalog.replace('\'', "''"); + let schema = schema.replace('\'', "''"); format!( "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ @@ -96,20 +85,19 @@ impl super::SqlDialect for OdbcDialect 
{ } fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { - let schema = schema.replace('\'', "''"); - let table = table.replace('\'', "''"); match self.variant { OdbcVariant::Snowflake => { let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); + let table_ident = table.replace('"', "\"\""); format!( - "SELECT column_name, data_type \ - FROM \"{catalog_ident}\".information_schema.columns \ - WHERE table_schema = '{schema}' AND table_name = '{table}' \ - ORDER BY ordinal_position" + "SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"" ) } _ => { let catalog = catalog.replace('\'', "''"); + let schema = schema.replace('\'', "''"); + let table = table.replace('\'', "''"); format!( "SELECT column_name, data_type FROM information_schema.columns \ WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ From f6e2b4e5fd0bb63a1a64a064c09a232a1c9cfd64 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Wed, 18 Mar 2026 15:15:50 +0000 Subject: [PATCH 05/20] WIP: ODBC --- src/execute/casting.rs | 4 ++-- src/execute/cte.rs | 14 +++++++------- src/naming.rs | 12 ++++++++++++ 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/execute/casting.rs b/src/execute/casting.rs index 5bf26c40..7c499b62 100644 --- a/src/execute/casting.rs +++ b/src/execute/casting.rs @@ -216,7 +216,7 @@ pub fn determine_layer_source( match &layer.source { Some(DataSource::Identifier(name)) => { if materialized_ctes.contains(name) { - naming::cte_table(name) + naming::quote_ident(&naming::cte_table(name)) } else { name.clone() } @@ -227,7 +227,7 @@ pub fn determine_layer_source( None => { // Layer uses global data - caller must ensure has_global is true debug_assert!(has_global, "Layer has no source and no global data"); - naming::global_table() + naming::quote_ident(&naming::global_table()) } } } diff --git a/src/execute/cte.rs b/src/execute/cte.rs index 5a6b665f..e86e7da7 100644 --- 
a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -94,7 +94,7 @@ pub fn transform_cte_references(sql: &str, cte_names: &HashSet) -> Strin let mut result = sql.to_string(); for cte_name in cte_names { - let temp_table_name = naming::cte_table(cte_name); + let temp_table_name = naming::quote_ident(&naming::cte_table(cte_name)); // Replace table references: FROM cte_name, JOIN cte_name, cte_name.column // Use word boundary matching to avoid replacing substrings @@ -360,7 +360,7 @@ mod tests { ( "SELECT * FROM sales WHERE year = 2024", vec!["sales"], - vec!["FROM __ggsql_cte_sales_", "__ WHERE year = 2024"], + vec!["FROM \"__ggsql_cte_sales_", "__\" WHERE year = 2024"], None, ), // Multiple CTE references with qualified columns @@ -368,10 +368,10 @@ mod tests { "SELECT sales.date, targets.revenue FROM sales JOIN targets ON sales.id = targets.id", vec!["sales", "targets"], vec![ - "FROM __ggsql_cte_sales_", - "JOIN __ggsql_cte_targets_", - "__ggsql_cte_sales_", // qualified reference sales.date - "__ggsql_cte_targets_", // qualified reference targets.revenue + "FROM \"__ggsql_cte_sales_", + "JOIN \"__ggsql_cte_targets_", + "\"__ggsql_cte_sales_", // qualified reference sales.date + "\"__ggsql_cte_targets_", // qualified reference targets.revenue ], None, ), @@ -379,7 +379,7 @@ mod tests { ( "WHERE sales.date > '2024-01-01' AND sales.revenue > 100", vec!["sales"], - vec!["__ggsql_cte_sales_"], + vec!["\"__ggsql_cte_sales_"], None, ), // No matching CTE (unchanged) diff --git a/src/naming.rs b/src/naming.rs index 882f40dc..cbd551a1 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -88,6 +88,18 @@ pub const SOURCE_COLUMN: &str = concatcp!(GGSQL_PREFIX, "source", GGSQL_SUFFIX); /// Alias for schema extraction queries pub const SCHEMA_ALIAS: &str = concatcp!(GGSQL_SUFFIX, "schema", GGSQL_SUFFIX); +// ============================================================================ +// Quoting +// ============================================================================ + +/// 
Quote a SQL identifier with double quotes (ANSI SQL). +/// +/// This ensures case-preserving behavior on backends like Snowflake that +/// uppercase unquoted identifiers. +pub fn quote_ident(name: &str) -> String { + format!("\"{}\"", name.replace('"', "\"\"")) +} + // ============================================================================ // Constructor Functions // ============================================================================ From ae29fa018e458eea96ab1a108750a6a0f20ecec9 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Thu, 19 Mar 2026 13:44:52 +0000 Subject: [PATCH 06/20] ODBC: Quoting is fun --- src/execute/casting.rs | 4 +- src/execute/cte.rs | 2 +- src/execute/layer.rs | 10 ++--- src/execute/schema.rs | 2 +- src/naming.rs | 12 ------ src/plot/layer/geom/bar.rs | 26 ++++++------- src/plot/layer/geom/boxplot.rs | 14 ++++--- src/plot/layer/geom/density.rs | 67 ++++++++++++++++---------------- src/plot/layer/geom/histogram.rs | 12 +++--- src/reader/data.rs | 8 ++-- src/reader/duckdb.rs | 6 +-- src/reader/mod.rs | 12 +++--- 12 files changed, 83 insertions(+), 92 deletions(-) diff --git a/src/execute/casting.rs b/src/execute/casting.rs index 7c499b62..5090c61e 100644 --- a/src/execute/casting.rs +++ b/src/execute/casting.rs @@ -216,7 +216,7 @@ pub fn determine_layer_source( match &layer.source { Some(DataSource::Identifier(name)) => { if materialized_ctes.contains(name) { - naming::quote_ident(&naming::cte_table(name)) + format!("\"{}\"", naming::cte_table(name)) } else { name.clone() } @@ -227,7 +227,7 @@ pub fn determine_layer_source( None => { // Layer uses global data - caller must ensure has_global is true debug_assert!(has_global, "Layer has no source and no global data"); - naming::quote_ident(&naming::global_table()) + format!("\"{}\"", naming::global_table()) } } } diff --git a/src/execute/cte.rs b/src/execute/cte.rs index e86e7da7..983b05e3 100644 --- a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -94,7 +94,7 @@ pub fn 
transform_cte_references(sql: &str, cte_names: &HashSet) -> Strin let mut result = sql.to_string(); for cte_name in cte_names { - let temp_table_name = naming::quote_ident(&naming::cte_table(cte_name)); + let temp_table_name = format!("\"{}\"", naming::cte_table(cte_name)); // Replace table references: FROM cte_name, JOIN cte_name, cte_name.column // Use word boundary matching to avoid replacing substrings diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 1ac344e7..c18facd4 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -271,7 +271,7 @@ pub fn apply_pre_stat_transform( .collect(); format!( - "SELECT {} FROM ({}) AS __ggsql_pre__", + "SELECT {} FROM ({}) AS \"__ggsql_pre__\"", select_exprs.join(", "), query ) @@ -315,14 +315,14 @@ pub fn build_layer_base_query( // Build query with optional WHERE clause if let Some(ref f) = layer.filter { format!( - "SELECT {} FROM ({}) AS __ggsql_src__ WHERE {}", + "SELECT {} FROM ({}) AS \"__ggsql_src__\" WHERE {}", select_clause, source_query, f.as_str() ) } else { format!( - "SELECT {} FROM ({}) AS __ggsql_src__", + "SELECT {} FROM ({}) AS \"__ggsql_src__\"", select_clause, source_query ) } @@ -542,14 +542,14 @@ where .and_then(super::cte::split_with_query) { format!( - "{}, __ggsql_stat__ AS ({}) SELECT *, {} FROM __ggsql_stat__", + "{}, \"__ggsql_stat__\" AS ({}) SELECT *, {} FROM \"__ggsql_stat__\"", cte_prefix, trailing_select, stat_rename_exprs.join(", ") ) } else { format!( - "SELECT *, {} FROM ({}) AS __ggsql_stat__", + "SELECT *, {} FROM ({}) AS \"__ggsql_stat__\"", stat_rename_exprs.join(", "), transformed_query ) diff --git a/src/execute/schema.rs b/src/execute/schema.rs index 568c7e81..7c7eace5 100644 --- a/src/execute/schema.rs +++ b/src/execute/schema.rs @@ -30,7 +30,7 @@ pub fn build_minmax_query(source_query: &str, column_names: &[&str]) -> String { .collect(); format!( - "WITH __ggsql_source__ AS ({}) SELECT {} FROM __ggsql_source__ UNION ALL SELECT {} FROM __ggsql_source__", + "WITH 
\"__ggsql_source__\" AS ({}) SELECT {} FROM \"__ggsql_source__\" UNION ALL SELECT {} FROM \"__ggsql_source__\"", source_query, min_exprs.join(", "), max_exprs.join(", ") diff --git a/src/naming.rs b/src/naming.rs index cbd551a1..882f40dc 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -88,18 +88,6 @@ pub const SOURCE_COLUMN: &str = concatcp!(GGSQL_PREFIX, "source", GGSQL_SUFFIX); /// Alias for schema extraction queries pub const SCHEMA_ALIAS: &str = concatcp!(GGSQL_SUFFIX, "schema", GGSQL_SUFFIX); -// ============================================================================ -// Quoting -// ============================================================================ - -/// Quote a SQL identifier with double quotes (ANSI SQL). -/// -/// This ensures case-preserving behavior on backends like Snowflake that -/// uppercase unquoted identifiers. -pub fn quote_ident(name: &str) -> String { - format!("\"{}\"", name.replace('"', "\"\"")) -} - // ============================================================================ // Constructor Functions // ============================================================================ diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index 9922d4e6..ee10841e 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -169,19 +169,19 @@ fn stat_bar_count( if let Some(weight_col) = weight_value.column_name() { if schema_columns.contains(weight_col) { // weight column exists - use SUM (but still call it "count") - format!("SUM({}) AS {}", weight_col, stat_count) + format!("SUM({}) AS \"{}\"", weight_col, stat_count) } else { // weight mapped but column doesn't exist - fall back to COUNT // (this shouldn't happen with upfront validation, but handle gracefully) - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", stat_count) } } else { // Shouldn't happen (not literal, not column), fall back to COUNT - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", 
stat_count) } } else { // weight not mapped - use COUNT - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", stat_count) }; // Build the query based on whether x is mapped or not @@ -191,13 +191,13 @@ fn stat_bar_count( let (grouped_select, final_select) = if group_by.is_empty() { ( format!( - "'{dummy}' AS {x}, {agg}", + "'{dummy}' AS \"{x}\", {agg}", dummy = stat_dummy_value, x = stat_x, agg = agg_expr ), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{prop}\"", count = stat_count, prop = stat_proportion ), @@ -206,14 +206,14 @@ fn stat_bar_count( let grp_cols = group_by.join(", "); ( format!( - "{g}, '{dummy}' AS {x}, {agg}", + "{g}, '{dummy}' AS \"{x}\", {agg}", g = grp_cols, dummy = stat_dummy_value, x = stat_x, agg = agg_expr ), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{prop}\"", count = stat_count, grp = grp_cols, prop = stat_proportion @@ -224,7 +224,7 @@ fn stat_bar_count( let query_str = if group_by.is_empty() { // No grouping at all - single aggregate format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\") SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, final = final_select @@ -233,7 +233,7 @@ fn stat_bar_count( // Group by partition/facet variables only let group_cols = group_by.join(", "); format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, @@ 
-271,7 +271,7 @@ fn stat_bar_count( ( format!("{x}, {agg}", x = x_col, agg = agg_expr), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{prop}\"", count = stat_count, prop = stat_proportion ), @@ -281,7 +281,7 @@ fn stat_bar_count( ( format!("{g}, {x}, {agg}", g = grp_cols, x = x_col, agg = agg_expr), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{prop}\"", count = stat_count, grp = grp_cols, prop = stat_proportion @@ -290,7 +290,7 @@ fn stat_bar_count( }; let query_str = format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index d94eef0a..e9d90723 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -177,6 +177,8 @@ fn boxplot_sql_compute_summary( let q1 = dialect.sql_percentile(value, 0.25, from, groups); let median = dialect.sql_percentile(value, 0.50, from, groups); let q3 = dialect.sql_percentile(value, 0.75, from, groups); + let qt = "\"__ggsql_qt__\""; + let fn_alias = "\"__ggsql_fn__\""; format!( "SELECT *, @@ -190,10 +192,10 @@ fn boxplot_sql_compute_summary( {q1} AS q1, {median} AS median, {q3} AS q3 - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS {qt} WHERE {value} IS NOT NULL GROUP BY {groups} - ) AS __ggsql_fn__", + ) AS {fn_alias}", lower_expr = lower_expr, upper_expr = upper_expr, groups = groups_str, @@ -397,10 +399,10 @@ mod tests { {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM sales) AS __ggsql_qt__ + FROM (SELECT * FROM sales) AS 
"__ggsql_qt__" WHERE price IS NOT NULL GROUP BY category - ) AS __ggsql_fn__"# + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); @@ -433,10 +435,10 @@ mod tests { {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM data) AS __ggsql_qt__ + FROM (SELECT * FROM data) AS "__ggsql_qt__" WHERE revenue IS NOT NULL GROUP BY region, product - ) AS __ggsql_fn__"# + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index c83d2ee4..2a11e01b 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -272,7 +272,7 @@ fn density_sql_bandwidth( SELECT {rule} AS bw{comma} {groups_str} - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS {qt} WHERE {value} IS NOT NULL {group_by} )", @@ -282,6 +282,7 @@ fn density_sql_bandwidth( groups_str = groups_str, comma = comma, from = from, + qt = "\"__ggsql_qt__\"", ) } @@ -408,8 +409,8 @@ fn build_grid_cte( if !has_groups { return format!( "{seq}, grid AS ( - SELECT {min} + (__ggsql_seq__.n * {diff} / {n_points}) AS x - FROM __ggsql_seq__ + SELECT {min} + (\"__ggsql_seq__\".n * {diff} / {n_points}) AS x + FROM \"__ggsql_seq__\" )", seq = seq, min = min, @@ -423,8 +424,8 @@ fn build_grid_cte( "{seq}, grid AS ( SELECT {groups}, - {min} + (__ggsql_seq__.n * {diff} / {n_points}) AS x - FROM __ggsql_seq__ + {min} + (\"__ggsql_seq__\".n * {diff} / {n_points}) AS x + FROM \"__ggsql_seq__\" CROSS JOIN (SELECT DISTINCT {groups} FROM ({from})) AS groups )", seq = seq, @@ -494,16 +495,16 @@ fn compute_density( {data_cte}, {grid_cte} SELECT - {x_column}, + \"{x_column}\", {groups} - {intensity_column}, - {intensity_column} / __norm AS {density_column} + \"{intensity_column}\", + \"{intensity_column}\" / \"__norm\" AS \"{density_column}\" FROM ( SELECT - grid.x AS {x_column}, + grid.x AS \"{x_column}\", {grid_groups} - {kernel} AS {intensity_column}, - SUM(data.weight) AS __norm + {kernel} AS \"{intensity_column}\", + 
SUM(data.weight) AS \"__norm\" {join_logic} {aggregation} )", @@ -543,32 +544,32 @@ mod tests { let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw), + let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw), data AS ( SELECT x AS val, 1.0 AS weight FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) WHERE x IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), grid AS ( - SELECT -0.5 + (__ggsql_seq__.n * 11 / 511) AS x - FROM __ggsql_seq__ + SELECT -0.5 + ("__ggsql_seq__".n * 11 / 511) AS x + FROM "__ggsql_seq__" ) SELECT - __ggsql_stat_x, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_x", + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + grid.x AS "__ggsql_stat_x", + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data INNER JOIN bandwidth ON true CROSS JOIN grid GROUP BY grid.x ORDER BY grid.x - )"; + 
)"#; // Normalize whitespace for comparison let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); @@ -606,38 +607,38 @@ mod tests { let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY region, category), + let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY region, category), data AS ( SELECT region, category, x AS val, 1.0 AS weight FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) WHERE x IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), grid AS ( SELECT region, category, - -11 + (__ggsql_seq__.n * 22 / 511) AS x - FROM __ggsql_seq__ + -11 + ("__ggsql_seq__".n * 22 / 511) AS x + FROM "__ggsql_seq__" CROSS JOIN (SELECT DISTINCT region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups ) SELECT - __ggsql_stat_x, + "__ggsql_stat_x", region, category, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" 
AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, + grid.x AS "__ggsql_stat_x", grid.region, grid.category, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data INNER JOIN bandwidth ON data.region IS NOT DISTINCT FROM bandwidth.region AND data.category IS NOT DISTINCT FROM bandwidth.category CROSS JOIN grid WHERE grid.region IS NOT DISTINCT FROM data.region AND grid.category IS NOT DISTINCT FROM data.category GROUP BY grid.x, grid.region, grid.category ORDER BY grid.x, grid.region, grid.category - )"; + )"#; // Normalize whitespace for comparison let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index b372d79e..90a9dddf 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -144,7 +144,7 @@ fn stat_histogram( // Query min/max to compute bin width let stats_query = format!( - "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS __ggsql_stats__", + "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS \"__ggsql_stats__\"", x = x_col, query = query ); @@ -227,11 +227,11 @@ fn stat_histogram( let (binned_select, final_select) = if group_by.is_empty() { ( format!( - "{} AS {}, {} AS {}, {} AS {}", + "{} AS \"{}\", {} AS \"{}\", {} AS \"{}\"", bin_expr, stat_bin, bin_end_expr, stat_bin_end, agg_expr, stat_count ), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {density}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{density}\"", count = stat_count, density = stat_density ), @@ -240,11 +240,11 @@ fn 
stat_histogram( let grp_cols = group_by.join(", "); ( format!( - "{}, {} AS {}, {} AS {}, {} AS {}", + "{}, {} AS \"{}\", {} AS \"{}\", {} AS \"{}\"", grp_cols, bin_expr, stat_bin, bin_end_expr, stat_bin_end, agg_expr, stat_count ), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {density}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{density}\"", count = stat_count, grp = grp_cols, density = stat_density @@ -253,7 +253,7 @@ fn stat_histogram( }; let transformed_query = format!( - "WITH __stat_src__ AS ({query}), __binned__ AS (SELECT {binned} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __binned__", + "WITH \"__stat_src__\" AS ({query}), \"__binned__\" AS (SELECT {binned} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__binned__\"", query = query, binned = binned_select, group = group_cols, diff --git a/src/reader/data.rs b/src/reader/data.rs index bb62b128..30b9ca6e 100644 --- a/src/reader/data.rs +++ b/src/reader/data.rs @@ -185,7 +185,7 @@ pub fn rewrite_namespaced_sql(sql: &str) -> Result { replacements.push(( node.start_byte(), node.end_byte(), - naming::builtin_data_table(name), + format!("\"{}\"", naming::builtin_data_table(name)), )); } } @@ -315,7 +315,7 @@ mod tests { fn test_rewrite_namespaced_sql_simple() { let sql = "SELECT * FROM ggsql:penguins"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert_eq!(rewritten, "SELECT * FROM __ggsql_data_penguins__"); + assert_eq!(rewritten, "SELECT * FROM \"__ggsql_data_penguins__\""); } #[test] @@ -324,7 +324,7 @@ mod tests { let rewritten = rewrite_namespaced_sql(sql).unwrap(); assert_eq!( rewritten, - "SELECT * FROM __ggsql_data_penguins__ p, __ggsql_data_airquality__ a WHERE p.id = a.id" + "SELECT * FROM \"__ggsql_data_penguins__\" p, \"__ggsql_data_airquality__\" a WHERE p.id = a.id" ); } @@ -339,7 +339,7 @@ mod tests { fn test_rewrite_namespaced_sql_with_visualise() { let sql = "SELECT * FROM ggsql:penguins VISUALISE 
DRAW point MAPPING bill_len AS x, bill_dep AS y"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert!(rewritten.starts_with("SELECT * FROM __ggsql_data_penguins__")); + assert!(rewritten.starts_with("SELECT * FROM \"__ggsql_data_penguins__\"")); assert!(!rewritten.contains("ggsql:")); } } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 3a0cf616..3ac23b24 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -36,7 +36,7 @@ impl super::SqlDialect for DuckDbDialect { fn sql_generate_series(&self, n: usize) -> String { format!( - "__ggsql_seq__(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", + "\"__ggsql_seq__\"(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", n - 1 ) } @@ -44,13 +44,13 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| format!("AND \"__ggsql_pct__\".{g} IS NOT DISTINCT FROM \"__ggsql_qt__\".{g}")) .collect::>() .join(" "); format!( "(SELECT QUANTILE_CONT({column}, {fraction}) \ - FROM ({from}) AS __ggsql_pct__ \ + FROM ({from}) AS \"__ggsql_pct__\" \ WHERE {column} IS NOT NULL {group_filter})" ) } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 02a79c63..512be7eb 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -167,12 +167,12 @@ pub trait SqlDialect { let base_sq = base_size * base_size; let base_max = base_size - 1; format!( - "__ggsql_base__(n) AS (\ - SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < {base_max}\ + "\"__ggsql_base__\"(n) AS (\ + SELECT 0 UNION ALL SELECT n + 1 FROM \"__ggsql_base__\" WHERE n < {base_max}\ ),\ - __ggsql_seq__(n) AS (\ + \"__ggsql_seq__\"(n) AS (\ SELECT CAST(a.n * {base_sq} + b.n * {base_size} + c.n AS REAL) AS n \ - FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c \ + FROM \"__ggsql_base__\" a, 
\"__ggsql_base__\" b, \"__ggsql_base__\" c \ WHERE a.n * {base_sq} + b.n * {base_size} + c.n < {n}\ )" ) @@ -186,7 +186,7 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| format!("AND \"__ggsql_pct__\".{g} IS NOT DISTINCT FROM \"__ggsql_qt__\".{g}")) .collect::>() .join(" "); @@ -201,7 +201,7 @@ pub trait SqlDialect { FROM (\ SELECT {column} AS __val, \ NTILE(4) OVER (ORDER BY {column}) AS __tile \ - FROM ({from}) AS __ggsql_pct__ \ + FROM ({from}) AS \"__ggsql_pct__\" \ WHERE {column} IS NOT NULL {group_filter}\ ))" ) From 5d68d92d341f374ba77b79c75662337ec1f7cf87 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 20 Mar 2026 10:29:28 +0000 Subject: [PATCH 07/20] WIP: Quoting is more fun --- src/plot/layer/geom/boxplot.rs | 65 +++++++++++++------------ src/plot/layer/geom/density.rs | 62 +++++++++++++---------- src/plot/scale/scale_type/binned.rs | 55 +++++++++++---------- src/plot/scale/scale_type/continuous.rs | 18 ++++--- src/plot/scale/scale_type/discrete.rs | 5 +- src/plot/scale/scale_type/ordinal.rs | 5 +- src/reader/duckdb.rs | 6 ++- src/reader/mod.rs | 6 ++- 8 files changed, 123 insertions(+), 99 deletions(-) diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index e9d90723..b87b2a21 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -171,7 +171,8 @@ fn boxplot_sql_compute_summary( coef: &f64, dialect: &dyn SqlDialect, ) -> String { - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let groups_str = quoted_groups.join(", "); let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]); let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]); let q1 = dialect.sql_percentile(value, 0.25, 
from, groups); @@ -179,6 +180,7 @@ fn boxplot_sql_compute_summary( let q3 = dialect.sql_percentile(value, 0.75, from, groups); let qt = "\"__ggsql_qt__\""; let fn_alias = "\"__ggsql_fn__\""; + let quoted_value = format!("\"{}\"", value); format!( "SELECT *, @@ -199,7 +201,7 @@ fn boxplot_sql_compute_summary( lower_expr = lower_expr, upper_expr = upper_expr, groups = groups_str, - value = value, + value = quoted_value, from = from, q1 = q1, median = median, @@ -211,10 +213,12 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St let mut join_pairs = Vec::new(); let mut keep_columns = Vec::new(); for column in groups { - join_pairs.push(format!("raw.{} = summary.{}", column, column)); - keep_columns.push(format!("raw.{}", column)); + let quoted = format!("\"{}\"", column); + join_pairs.push(format!("raw.{} = summary.{}", quoted, quoted)); + keep_columns.push(format!("raw.{}", quoted)); } + let quoted_value = format!("\"{}\"", value); // We're joining outliers with the summary to use the lower/upper whisker // values as a filter format!( @@ -225,7 +229,7 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St FROM ({from}) raw JOIN summary ON {pairs} WHERE raw.{value} NOT BETWEEN summary.lower AND summary.upper", - value = value, + value = quoted_value, groups = keep_columns.join(", "), pairs = join_pairs.join(" AND "), from = from @@ -243,7 +247,8 @@ fn boxplot_sql_append_outliers( let value2_name = naming::stat_column("value2"); let type_name = naming::stat_column("type"); - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let groups_str = quoted_groups.join(", "); // Helper to build visual-element rows from summary table // Each row type maps to one visual element with y and yend where needed @@ -326,14 +331,14 @@ mod tests { fn test_sql_compute_summary_basic() { let groups = vec!["category".to_string()]; let result = 
boxplot_sql_compute_summary("data", &groups, "value", &1.5, &AnsiDialect); - assert!(result.contains("NTILE(4) OVER (ORDER BY value)")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"value\")")); assert!(result.contains("AS q1")); assert!(result.contains("AS median")); assert!(result.contains("AS q3")); - assert!(result.contains("MIN(value) AS min")); - assert!(result.contains("MAX(value) AS max")); - assert!(result.contains("WHERE value IS NOT NULL")); - assert!(result.contains("GROUP BY category")); + assert!(result.contains("MIN(\"value\") AS min")); + assert!(result.contains("MAX(\"value\") AS max")); + assert!(result.contains("WHERE \"value\" IS NOT NULL")); + assert!(result.contains("GROUP BY \"category\"")); assert!(result.contains("CASE WHEN (q1 - 1.5")); assert!(result.contains("CASE WHEN (q3 + 1.5")); } @@ -342,8 +347,8 @@ mod tests { fn test_sql_compute_summary_multiple_groups() { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5, &AnsiDialect); - assert!(result.contains("GROUP BY cat, region")); - assert!(result.contains("NTILE(4) OVER (ORDER BY val)")); + assert!(result.contains("GROUP BY \"cat\", \"region\"")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"val\")")); } #[test] @@ -364,8 +369,8 @@ mod tests { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_filter_outliers(&groups, "value", "raw_data"); assert!(result.contains("JOIN summary ON")); - assert!(result.contains("raw.cat = summary.cat")); - assert!(result.contains("raw.region = summary.region")); + assert!(result.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(result.contains("raw.\"region\" = summary.\"region\"")); assert!(result.contains("NOT BETWEEN summary.lower AND summary.upper")); assert!(result.contains("'outlier' AS type")); } @@ -393,15 +398,15 @@ mod tests { (CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS 
upper FROM ( SELECT - category, - MIN(price) AS min, - MAX(price) AS max, + "category", + MIN("price") AS min, + MAX("price") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 FROM (SELECT * FROM sales) AS "__ggsql_qt__" - WHERE price IS NOT NULL - GROUP BY category + WHERE "price" IS NOT NULL + GROUP BY "category" ) AS "__ggsql_fn__""# ); @@ -429,15 +434,15 @@ mod tests { (CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper FROM ( SELECT - region, product, - MIN(revenue) AS min, - MAX(revenue) AS max, + "region", "product", + MIN("revenue") AS min, + MAX("revenue") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 FROM (SELECT * FROM data) AS "__ggsql_qt__" - WHERE revenue IS NOT NULL - GROUP BY region, product + WHERE "revenue" IS NOT NULL + GROUP BY "region", "product" ) AS "__ggsql_fn__""# ); @@ -501,8 +506,8 @@ mod tests { let raw = "(SELECT * FROM raw_data)"; let result = boxplot_sql_append_outliers(summary, &groups, "val", raw, &true); - // Verify all groups are present - assert!(result.contains("cat, region, year")); + // Verify all groups are present (quoted) + assert!(result.contains("\"cat\", \"region\", \"year\"")); // Check structure assert!(result.contains("WITH")); @@ -511,9 +516,9 @@ mod tests { // Verify outlier join conditions for all groups let outlier_section = result.split("outliers AS").nth(1).unwrap(); - assert!(outlier_section.contains("raw.cat = summary.cat")); - assert!(outlier_section.contains("raw.region = summary.region")); - assert!(outlier_section.contains("raw.year = summary.year")); + assert!(outlier_section.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(outlier_section.contains("raw.\"region\" = summary.\"region\"")); + assert!(outlier_section.contains("raw.\"year\" = summary.\"year\"")); } // ==================== Parameter Validation Tests ==================== diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index 2a11e01b..77f6497a 100644 --- 
a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -182,13 +182,14 @@ fn compute_range_sql( from: &str, execute: &dyn Fn(&str) -> crate::Result, ) -> Result<(f64, f64)> { + let quoted_value = format!("\"{}\"", value); let query = format!( "SELECT MIN({value}) AS min, MAX({value}) AS max FROM ({from}) WHERE {value} IS NOT NULL", - value = value, + value = quoted_value, from = from ); let result = execute(&query)?; @@ -234,7 +235,8 @@ fn density_sql_bandwidth( ) -> String { let mut group_by = String::new(); let mut comma = String::new(); - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let groups_str = quoted_groups.join(", "); if !groups_str.is_empty() { group_by = format!("GROUP BY {}", groups_str); @@ -266,6 +268,7 @@ fn density_sql_bandwidth( }; return cte; } + let quoted_value = format!("\"{}\"", value); format!( "WITH RECURSIVE bandwidth AS ( @@ -277,7 +280,7 @@ fn density_sql_bandwidth( {group_by} )", rule = silverman_rule(adjust, value, from, groups, dialect), - value = value, + value = quoted_value, group_by = group_by, groups_str = groups_str, comma = comma, @@ -296,7 +299,8 @@ fn silverman_rule( // The query computes Silverman's rule of thumb (R's `stats::bw.nrd0()`). 
// We absorb the adjustment in the 0.9 multiplier of the rule let adjust = 0.9 * adjust; - let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = value_column); + let v = format!("\"{}\"", value_column); + let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = v); let q75 = dialect.sql_percentile(value_column, 0.75, from, groups); let q25 = dialect.sql_percentile(value_column, 0.25, from, groups); let iqr = format!("({q75} - {q25}) / 1.34"); @@ -364,22 +368,24 @@ fn choose_kde_kernel(parameters: &HashMap) -> Result, from: &str, group_by: &[String]) -> String { // Include weight column if provided, otherwise default to 1.0 let weight_col = if let Some(w) = weight { - format!(", {} AS weight", w) + format!(", \"{}\" AS weight", w) } else { ", 1.0 AS weight".to_string() }; + let quoted_value = format!("\"{}\"", value); // Only filter out nulls in value column, keep NULLs in group columns - let filter_valid = format!("{} IS NOT NULL", value); + let filter_valid = format!("{} IS NOT NULL", quoted_value); + let quoted_groups: Vec<String> = group_by.iter().map(|g| format!("\"{}\"", g)).collect(); format!( "data AS ( SELECT {groups}{value} AS val{weight_col} FROM ({from}) WHERE {filter_valid} )", - groups = with_trailing_comma(&group_by.join(", ")), - value = value, + groups = with_trailing_comma(&quoted_groups.join(", ")), + value = quoted_value, weight_col = weight_col, from = from, filter_valid = filter_valid @@ -419,7 +425,8 @@ fn build_grid_cte( ); } - let groups = groups.join(", "); + let quoted_groups: Vec<String> = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let groups = quoted_groups.join(", "); format!( "{seq}, grid AS ( SELECT @@ -451,7 +458,7 @@ fn compute_density( } else { group_by .iter() - .map(|g| format!("data.{col} IS NOT DISTINCT FROM bandwidth.{col}", col = g)) + .map(|g| format!("data.\"{}\" IS NOT DISTINCT FROM bandwidth.\"{}\"", g, g)) .collect::>() .join(" AND ") }; @@ -462,7 +469,7 @@ fn compute_density( } else { let
grid_data_conds: Vec = group_by .iter() - .map(|g| format!("grid.{col} IS NOT DISTINCT FROM data.{col}", col = g)) + .map(|g| format!("grid.\"{}\" IS NOT DISTINCT FROM data.\"{}\"", g, g)) .collect(); format!("WHERE {}", grid_data_conds.join(" AND ")) }; @@ -476,7 +483,7 @@ fn compute_density( ); // Build group-related SQL fragments - let grid_groups: Vec = group_by.iter().map(|g| format!("grid.{}", g)).collect(); + let grid_groups: Vec = group_by.iter().map(|g| format!("grid.\"{}\"", g)).collect(); let aggregation = format!( "GROUP BY grid.x{grid_group_by} ORDER BY grid.x{grid_group_by}", @@ -486,7 +493,8 @@ fn compute_density( let groups = if group_by.is_empty() { String::new() } else { - format!("{},", group_by.join(", ")) + let quoted: Vec = group_by.iter().map(|g| format!("\"{}\"", g)).collect(); + format!("{},", quoted.join(", ")) }; // Generate the density computation query @@ -546,9 +554,9 @@ mod tests { let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw), data AS ( - SELECT x AS val, 1.0 AS weight + SELECT "x" AS val, 1.0 AS weight FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), grid AS ( @@ -607,37 +615,37 @@ mod tests { let kernel = choose_kde_kernel(¶meters).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY region, category), + let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, "region", "category" FROM (SELECT x, region, category FROM (VALUES 
(1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY "region", "category"), data AS ( - SELECT region, category, x AS val, 1.0 AS weight + SELECT "region", "category", "x" AS val, 1.0 AS weight FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), grid AS ( SELECT - region, category, + "region", "category", -11 + ("__ggsql_seq__".n * 22 / 511) AS x FROM "__ggsql_seq__" - CROSS JOIN (SELECT DISTINCT region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups + CROSS JOIN (SELECT DISTINCT "region", "category" FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups ) SELECT "__ggsql_stat_x", - region, category, + "region", "category", "__ggsql_stat_intensity", "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT grid.x AS "__ggsql_stat_x", - grid.region, grid.category, + grid."region", grid."category", SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", SUM(data.weight) AS "__norm" FROM data - INNER JOIN bandwidth ON data.region IS NOT DISTINCT FROM bandwidth.region AND data.category IS NOT DISTINCT FROM bandwidth.category + INNER JOIN bandwidth ON data."region" IS NOT DISTINCT FROM bandwidth."region" AND data."category" IS NOT DISTINCT FROM bandwidth."category" CROSS JOIN grid - WHERE grid.region IS NOT DISTINCT FROM data.region AND grid.category IS NOT DISTINCT FROM data.category - GROUP BY grid.x, 
grid.region, grid.category - ORDER BY grid.x, grid.region, grid.category + WHERE grid."region" IS NOT DISTINCT FROM data."region" AND grid."category" IS NOT DISTINCT FROM data."category" + GROUP BY grid.x, grid."region", grid."category" + ORDER BY grid.x, grid."region", grid."category" )"#; // Normalize whitespace for comparison @@ -718,7 +726,7 @@ mod tests { // Verify SQL uses NTILE-based percentile subqueries with grouping assert!(bw_cte.contains("NTILE(4)")); - assert!(bw_cte.contains("GROUP BY region")); + assert!(bw_cte.contains("GROUP BY \"region\"")); let expected_rule = silverman_rule(1.0, "x", query, &groups, &AnsiDialect); assert!(normalize(&bw_cte).contains(&normalize(&expected_rule))); diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index 25edd0b3..2c32d650 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -692,20 +692,21 @@ fn build_bin_condition( (if is_first { ">=" } else { ">" }, "<=") }; + let quoted = format!("\"{}\"", column_name); if oob_squish && is_first && is_last { // Single bin with squish: capture everything "TRUE".to_string() } else if oob_squish && is_first { // First bin with squish: no lower bound, extends to -∞ - format!("{} {} {}", column_name, upper_op, upper_expr) + format!("{} {} {}", quoted, upper_op, upper_expr) } else if oob_squish && is_last { // Last bin with squish: no upper bound, extends to +∞ - format!("{} {} {}", column_name, lower_op, lower_expr) + format!("{} {} {}", quoted, lower_op, lower_expr) } else { // Normal bin with both bounds format!( "{} {} {} AND {} {} {}", - column_name, lower_op, lower_expr, column_name, upper_op, upper_expr + quoted, lower_op, lower_expr, quoted, upper_op, upper_expr ) } } @@ -820,10 +821,10 @@ mod tests { // Should produce CASE WHEN with bin centers 5, 15, 25 assert!(sql.contains("CASE")); - assert!(sql.contains("WHEN value >= 0 AND value < 10 THEN 5")); - assert!(sql.contains("WHEN value >= 10 AND value 
< 20 THEN 15")); + assert!(sql.contains("WHEN \"value\" >= 0 AND \"value\" < 10 THEN 5")); + assert!(sql.contains("WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15")); // Last bin should be inclusive on both ends - assert!(sql.contains("WHEN value >= 20 AND value <= 30 THEN 25")); + assert!(sql.contains("WHEN \"value\" >= 20 AND \"value\" <= 30 THEN 25")); assert!(sql.contains("ELSE NULL END")); } @@ -871,8 +872,8 @@ mod tests { .unwrap(); // closed="left": [lower, upper) except last which is [lower, upper] - assert!(sql.contains("col >= 0 AND col < 10")); - assert!(sql.contains("col >= 10 AND col <= 20")); // last bin inclusive + assert!(sql.contains("\"col\" >= 0 AND \"col\" < 10")); + assert!(sql.contains("\"col\" >= 10 AND \"col\" <= 20")); // last bin inclusive } #[test] @@ -897,8 +898,8 @@ mod tests { .unwrap(); // closed="right": first bin is [lower, upper], rest are (lower, upper] - assert!(sql.contains("col >= 0 AND col <= 10")); // first bin inclusive - assert!(sql.contains("col > 10 AND col <= 20")); + assert!(sql.contains("\"col\" >= 0 AND \"col\" <= 10")); // first bin inclusive + assert!(sql.contains("\"col\" > 10 AND \"col\" <= 20")); } #[test] @@ -1156,8 +1157,8 @@ mod tests { sql ); assert!( - sql.contains("value >= 0"), - "SQL should use raw column name. Got: {}", + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name. 
Got: {}", sql ); assert!( @@ -1192,7 +1193,7 @@ mod tests { !sql.contains("CAST("), "SQL should not contain CAST when column is numeric" ); - assert!(sql.contains("value >= 0"), "SQL should use raw column name"); + assert!(sql.contains("\"value\" >= 0"), "SQL should use quoted column name"); } #[test] @@ -1468,9 +1469,9 @@ mod tests { "left", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value < 10 THEN 5", // First bin extends to -∞ - "WHEN value >= 10 AND value < 20 THEN 15", // Middle bin - "WHEN value >= 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" < 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15", // Middle bin + "WHEN \"value\" >= 20 THEN 25", // Last bin extends to +∞ ], ), // closed="right" with 3 bins (4 breaks) @@ -1478,9 +1479,9 @@ mod tests { "right", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value <= 10 THEN 5", // First bin extends to -∞ - "WHEN value > 10 AND value <= 20 THEN 15", // Middle bin - "WHEN value > 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" <= 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" > 10 AND \"value\" <= 20 THEN 15", // Middle bin + "WHEN \"value\" > 20 THEN 25", // Last bin extends to +∞ ], ), ]; @@ -1541,11 +1542,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("WHEN x < 50 THEN 25"), + sql.contains("WHEN \"x\" < 50 THEN 25"), "Two bins: first should extend to -∞" ); assert!( - sql.contains("WHEN x >= 50 THEN 75"), + sql.contains("WHEN \"x\" >= 50 THEN 75"), "Two bins: last should extend to +∞" ); } @@ -1590,11 +1591,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("x >= 0 AND x < 10"), + sql.contains("\"x\" >= 0 AND \"x\" < 10"), "First bin should have lower bound with censor" ); assert!( - sql.contains("x >= 10 AND x <= 20"), + sql.contains("\"x\" >= 10 AND \"x\" <= 20"), "Last bin should have upper bound 
with censor" ); } @@ -1607,14 +1608,14 @@ mod tests { ( true, vec![ - "WHEN col < 10 THEN 5", - "WHEN col >= 10 AND col < 20 THEN 15", - "WHEN col >= 20 THEN 25", + "WHEN \"col\" < 10 THEN 5", + "WHEN \"col\" >= 10 AND \"col\" < 20 THEN 15", + "WHEN \"col\" >= 20 THEN 25", ], ), ( false, - vec!["col >= 0 AND col < 10", "col >= 10 AND col <= 20"], + vec!["\"col\" >= 0 AND \"col\" < 10", "\"col\" >= 10 AND \"col\" <= 20"], ), ]; diff --git a/src/plot/scale/scale_type/continuous.rs b/src/plot/scale/scale_type/continuous.rs index 6ccf750f..fce0a28e 100644 --- a/src/plot/scale/scale_type/continuous.rs +++ b/src/plot/scale/scale_type/continuous.rs @@ -192,14 +192,18 @@ impl ScaleTypeTrait for Continuous { .unwrap_or(super::default_oob(&scale.aesthetic)); match oob { - OOB_CENSOR => Some(format!( - "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", - column_name, min, column_name, max, column_name - )), + OOB_CENSOR => { + let quoted = format!("\"{}\"", column_name); + Some(format!( + "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", + quoted, min, quoted, max, quoted + )) + } OOB_SQUISH => { let min_s = min.to_string(); let max_s = max.to_string(); - let inner = dialect.sql_least(&[&max_s, column_name]); + let quoted = format!("\"{}\"", column_name); + let inner = dialect.sql_least(&[&max_s, &quoted]); Some(dialect.sql_greatest(&[&min_s, &inner])) } _ => None, // "keep" = no transformation @@ -237,8 +241,8 @@ mod tests { let sql = sql.unwrap(); // Should generate CASE WHEN for censor assert!(sql.contains("CASE WHEN")); - assert!(sql.contains("value >= 0")); - assert!(sql.contains("value <= 100")); + assert!(sql.contains("\"value\" >= 0")); + assert!(sql.contains("\"value\" <= 100")); assert!(sql.contains("ELSE NULL")); } diff --git a/src/plot/scale/scale_type/discrete.rs b/src/plot/scale/scale_type/discrete.rs index a81686ce..14c1412f 100644 --- a/src/plot/scale/scale_type/discrete.rs +++ b/src/plot/scale/scale_type/discrete.rs @@ -261,11 +261,12 @@ impl
ScaleTypeTrait for Discrete { } // Always censor - discrete scales have no other valid OOB behavior + let quoted = format!("\"{}\"", column_name); Some(format!( "(CASE WHEN {} IN ({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/plot/scale/scale_type/ordinal.rs b/src/plot/scale/scale_type/ordinal.rs index a315af14..793f9126 100644 --- a/src/plot/scale/scale_type/ordinal.rs +++ b/src/plot/scale/scale_type/ordinal.rs @@ -293,11 +293,12 @@ impl ScaleTypeTrait for Ordinal { } // Always censor - ordinal scales have no other valid OOB behavior + let quoted = format!("\"{}\"", column_name); Some(format!( "(CASE WHEN {} IN ({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 3ac23b24..1cf0de55 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -44,14 +44,16 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| format!("AND \"__ggsql_pct__\".{g} IS NOT DISTINCT FROM \"__ggsql_qt__\".{g}")) + .map(|g| format!("AND \"__ggsql_pct__\".\"{g}\" IS NOT DISTINCT FROM \"__ggsql_qt__\".\"{g}\"")) .collect::>() .join(" "); + let quoted_column = format!("\"{}\"", column); format!( "(SELECT QUANTILE_CONT({column}, {fraction}) \ FROM ({from}) AS \"__ggsql_pct__\" \ - WHERE {column} IS NOT NULL {group_filter})" + WHERE {column} IS NOT NULL {group_filter})", + column = quoted_column ) } } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 512be7eb..020ac5d3 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -186,12 +186,13 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. 
let group_filter = groups .iter() - .map(|g| format!("AND \"__ggsql_pct__\".{g} IS NOT DISTINCT FROM \"__ggsql_qt__\".{g}")) + .map(|g| format!("AND \"__ggsql_pct__\".\"{g}\" IS NOT DISTINCT FROM \"__ggsql_qt__\".\"{g}\"")) .collect::>() .join(" "); let lo_tile = (fraction * 4.0).ceil() as usize; let hi_tile = lo_tile + 1; + let quoted_column = format!("\"{}\"", column); format!( "(SELECT (\ @@ -203,7 +204,8 @@ pub trait SqlDialect { NTILE(4) OVER (ORDER BY {column}) AS __tile \ FROM ({from}) AS \"__ggsql_pct__\" \ WHERE {column} IS NOT NULL {group_filter}\ - ))" + ))", + column = quoted_column ) } } From f59490c9a4add7578852f0677826ade8b0d8c884 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Fri, 20 Mar 2026 11:16:01 +0000 Subject: [PATCH 08/20] WIP: Quoting --- src/plot/layer/geom/boxplot.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index b87b2a21..2534422d 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -254,13 +254,13 @@ fn boxplot_sql_append_outliers( // Each row type maps to one visual element with y and yend where needed let build_summary_select = |table: &str| { format!( - "SELECT {groups}, 'lower_whisker' AS {type_name}, q1 AS {value_name}, lower AS {value2_name} FROM {table} + "SELECT {groups}, 'lower_whisker' AS \"{type_name}\", q1 AS \"{value_name}\", lower AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'upper_whisker' AS {type_name}, q3 AS {value_name}, upper AS {value2_name} FROM {table} + SELECT {groups}, 'upper_whisker' AS \"{type_name}\", q3 AS \"{value_name}\", upper AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'box' AS {type_name}, q1 AS {value_name}, q3 AS {value2_name} FROM {table} + SELECT {groups}, 'box' AS \"{type_name}\", q1 AS \"{value_name}\", q3 AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'median' AS {type_name}, median AS {value_name}, 
NULL AS {value2_name} FROM {table}", + SELECT {groups}, 'median' AS \"{type_name}\", median AS \"{value_name}\", NULL AS \"{value2_name}\" FROM {table}", groups = groups_str, type_name = type_name, value_name = value_name, @@ -291,7 +291,7 @@ fn boxplot_sql_append_outliers( ) {summary_select} UNION ALL - SELECT {groups}, type AS {type_name}, value AS {value_name}, NULL AS {value2_name} + SELECT {groups}, type AS \"{type_name}\", value AS \"{value_name}\", NULL AS \"{value2_name}\" FROM outliers ", summary = from, @@ -470,9 +470,9 @@ mod tests { assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] @@ -494,9 +494,9 @@ mod tests { assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] From c2e89760fb05f38ca1002621260160af0c5ce02f Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 09:35:51 +0100 Subject: [PATCH 09/20] sqlite in executor.rs --- ggsql-jupyter/Cargo.toml | 1 + ggsql-jupyter/src/executor.rs | 58 ++++++++++++++++++++--------------- src/Cargo.toml | 2 +- 3 files changed, 36 
insertions(+), 25 deletions(-) diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml index 6f4bb1f0..330dbe91 100644 --- a/ggsql-jupyter/Cargo.toml +++ b/ggsql-jupyter/Cargo.toml @@ -59,6 +59,7 @@ uuid = { version = "1.0", features = ["v4"] } [features] default = [] odbc = ["ggsql/odbc"] +sqlite = ["ggsql/sqlite"] [dev-dependencies] # Test utilities diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs index 9eb3448a..434345c3 100644 --- a/ggsql-jupyter/src/executor.rs +++ b/ggsql-jupyter/src/executor.rs @@ -22,16 +22,14 @@ pub enum ExecutionResult { spec: String, // Vega-Lite JSON }, /// Connection changed via meta-command - ConnectionChanged { - uri: String, - display_name: String, - }, + ConnectionChanged { uri: String, display_name: String }, } /// Create a reader from a connection URI string. /// /// Supported schemes: /// - `duckdb://memory` or `duckdb://` (always available) +/// - `sqlite://` (requires `sqlite` feature) /// - `odbc://...` (requires `odbc` feature) pub fn create_reader(uri: &str) -> Result> { use ggsql::reader::connection::ConnectionInfo; @@ -43,22 +41,22 @@ pub fn create_reader(uri: &str) -> Result> { Ok(Box::new(reader)) } ConnectionInfo::DuckDBFile(path) => { - let reader = - DuckDBReader::from_connection_string(&format!("duckdb://{}", path))?; + let reader = DuckDBReader::from_connection_string(&format!("duckdb://{}", path))?; Ok(Box::new(reader)) } #[cfg(feature = "odbc")] ConnectionInfo::ODBC(conn_str) => { - let reader = ggsql::reader::OdbcReader::from_connection_string( - &format!("odbc://{}", conn_str), - )?; + let reader = + ggsql::reader::OdbcReader::from_connection_string(&format!("odbc://{}", conn_str))?; Ok(Box::new(reader)) } - _ => anyhow::bail!( - "Unsupported reader type for connection string: {}. 
\ - Only DuckDB connections are currently supported in ggsql-jupyter.", - uri - ), + #[cfg(feature = "sqlite")] + ConnectionInfo::SQLite(path) => { + let reader = + ggsql::reader::SqliteReader::from_connection_string(&format!("sqlite://{}", path))?; + Ok(Box::new(reader)) + } + _ => anyhow::bail!("Unsupported reader type for connection string: {}", uri), } } @@ -70,6 +68,12 @@ pub fn display_name_for_uri(uri: &str) -> String { if let Some(path) = uri.strip_prefix("duckdb://") { return format!("DuckDB ({})", path); } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "SQLite (memory)".to_string(); + } + return format!("SQLite ({})", path); + } if let Some(odbc) = uri.strip_prefix("odbc://") { // Try to extract driver name from ODBC string if let Some(driver_start) = odbc.to_lowercase().find("driver=") { @@ -91,11 +95,16 @@ pub fn type_name_for_uri(uri: &str) -> String { if uri.starts_with("duckdb://") { return "DuckDB".to_string(); } + if uri.starts_with("sqlite://") { + return "SQLite".to_string(); + } if let Some(odbc) = uri.strip_prefix("odbc://") { if odbc.to_lowercase().contains("driver=snowflake") { return "Snowflake".to_string(); } - if odbc.to_lowercase().contains("driver={postgresql}") || odbc.to_lowercase().contains("driver=postgresql") { + if odbc.to_lowercase().contains("driver={postgresql}") + || odbc.to_lowercase().contains("driver=postgresql") + { return "PostgreSQL".to_string(); } return "ODBC".to_string(); @@ -111,6 +120,12 @@ pub fn host_for_uri(uri: &str) -> String { if let Some(path) = uri.strip_prefix("duckdb://") { return path.to_string(); } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "memory".to_string(); + } + return path.to_string(); + } if let Some(odbc) = uri.strip_prefix("odbc://") { // Try to extract server if let Some(server_start) = odbc.to_lowercase().find("server=") { @@ -129,11 +144,9 @@ const META_CONNECT_PREFIX: &str = "-- @connect:"; /// Parse a `-- 
@connect: <uri>` meta-command, returning the URI if present.
pub fn parse_meta_command(code: &str) -> Option<String> {
     let trimmed = code.trim();
-    if let Some(rest) = trimmed.strip_prefix(META_CONNECT_PREFIX) {
-        Some(rest.trim().to_string())
-    } else {
-        None
-    }
+    trimmed
+        .strip_prefix(META_CONNECT_PREFIX)
+        .map(|rest| rest.trim().to_string())
 }
 
 /// Query executor maintaining persistent database connection
@@ -195,10 +208,7 @@ impl QueryExecutor {
             tracing::info!("Meta-command: switching reader to {}", uri);
             self.swap_reader(&uri)?;
             let display_name = display_name_for_uri(&uri);
-            return Ok(ExecutionResult::ConnectionChanged {
-                uri,
-                display_name,
-            });
+            return Ok(ExecutionResult::ConnectionChanged { uri, display_name });
         }
 
         // 1. Validate to check if there's a visualization
diff --git a/src/Cargo.toml b/src/Cargo.toml
index aca3ffd8..d8462fa0 100644
--- a/src/Cargo.toml
+++ b/src/Cargo.toml
@@ -79,7 +79,7 @@ tempfile = "3.8"
 ureq = "3"
 
 [features]
-default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data"]
+default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data", "odbc"]
 ipc = ["polars/ipc"]
 duckdb = ["dep:duckdb", "dep:arrow"]
 parquet = ["polars/parquet"]

From 46a4bc3a04ff27a3019c91e580abf173cd02bc37 Mon Sep 17 00:00:00 2001
From: George Stagg
Date: Tue, 31 Mar 2026 09:59:21 +0100
Subject: [PATCH 10/20] Use general SqlDialect in ODBC reader

---
 ggsql-jupyter/Cargo.toml |   3 +-
 src/reader/mod.rs        |   3 +
 src/reader/odbc.rs       | 180 +++++++++++----------
 src/reader/snowflake.rs  |  32 +++++++
 4 files changed, 87 insertions(+), 131 deletions(-)
 create mode 100644 src/reader/snowflake.rs

diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml
index 330dbe91..eadf493a 100644
--- a/ggsql-jupyter/Cargo.toml
+++ b/ggsql-jupyter/Cargo.toml
@@ -57,7 +57,8 @@ hex = "0.4"
 uuid = { version = "1.0", features = ["v4"] }
 
 [features]
-default = []
+default = ["all-readers"]
+all-readers = ["sqlite", "odbc"]
 odbc = ["ggsql/odbc"] 
sqlite = ["ggsql/sqlite"] diff --git a/src/reader/mod.rs b/src/reader/mod.rs index bc585087..e71bd0b0 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -257,6 +257,9 @@ pub mod sqlite; #[cfg(feature = "odbc")] pub mod odbc; +#[cfg(feature = "odbc")] +pub mod snowflake; + pub mod connection; pub mod data; mod spec; diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs index cb3f06c0..d6f47742 100644 --- a/src/reader/odbc.rs +++ b/src/reader/odbc.rs @@ -17,101 +17,41 @@ fn odbc_env() -> &'static Environment { ENV.get_or_init(|| Environment::new().expect("Failed to create ODBC environment")) } -/// ODBC SQL dialect. +/// Detect the backend SQL dialect from an ODBC connection string. /// -/// Uses ANSI SQL by default. The `variant` field can be used to detect -/// specific backends for dialect customization. -pub struct OdbcDialect { - #[allow(dead_code)] - variant: OdbcVariant, -} - -/// Detected ODBC backend variant. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum OdbcVariant { - Generic, - Snowflake, - PostgreSQL, - SqlServer, -} - -impl super::SqlDialect for OdbcDialect { - fn sql_list_catalogs(&self) -> String { - match self.variant { - OdbcVariant::Snowflake => "SHOW DATABASES".into(), - _ => { - "SELECT DISTINCT catalog_name FROM information_schema.schemata \ - ORDER BY catalog_name" - .into() - } +/// Returns a dialect matching the detected backend (e.g. Snowflake, SQLite, +/// DuckDB, or ANSI for generic/unknown backends). 
+fn detect_dialect(conn_str: &str) -> Box { + let lower = conn_str.to_lowercase(); + if lower.contains("driver=snowflake") { + Box::new(super::snowflake::SnowflakeDialect) + } else if lower.contains("driver=sqlite") || lower.contains("driver={sqlite") { + #[cfg(feature = "sqlite")] + { + Box::new(super::sqlite::SqliteDialect) } - } - - fn sql_list_schemas(&self, catalog: &str) -> String { - match self.variant { - OdbcVariant::Snowflake => { - let catalog_ident = catalog.replace('"', "\"\""); - format!("SHOW SCHEMAS IN DATABASE \"{catalog_ident}\"") - } - _ => { - let catalog = catalog.replace('\'', "''"); - format!( - "SELECT DISTINCT schema_name FROM information_schema.schemata \ - WHERE catalog_name = '{catalog}' ORDER BY schema_name" - ) - } + #[cfg(not(feature = "sqlite"))] + { + Box::new(super::AnsiDialect) } - } - - fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { - match self.variant { - OdbcVariant::Snowflake => { - let catalog_ident = catalog.replace('"', "\"\""); - let schema_ident = schema.replace('"', "\"\""); - format!( - "SHOW OBJECTS IN SCHEMA \"{catalog_ident}\".\"{schema_ident}\"" - ) - } - _ => { - let catalog = catalog.replace('\'', "''"); - let schema = schema.replace('\'', "''"); - format!( - "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ - WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ - ORDER BY table_name" - ) - } + } else if lower.contains("driver=duckdb") || lower.contains("driver={duckdb") { + #[cfg(feature = "duckdb")] + { + Box::new(super::duckdb::DuckDbDialect) } - } - - fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { - match self.variant { - OdbcVariant::Snowflake => { - let catalog_ident = catalog.replace('"', "\"\""); - let schema_ident = schema.replace('"', "\"\""); - let table_ident = table.replace('"', "\"\""); - format!( - "SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"" - ) - } - _ => { - let catalog = 
catalog.replace('\'', "''"); - let schema = schema.replace('\'', "''"); - let table = table.replace('\'', "''"); - format!( - "SELECT column_name, data_type FROM information_schema.columns \ - WHERE table_catalog = '{catalog}' AND table_schema = '{schema}' \ - AND table_name = '{table}' ORDER BY ordinal_position" - ) - } + #[cfg(not(feature = "duckdb"))] + { + Box::new(super::AnsiDialect) } + } else { + Box::new(super::AnsiDialect) } } /// Generic ODBC reader implementing the `Reader` trait. pub struct OdbcReader { connection: odbc_api::Connection<'static>, - dialect: OdbcDialect, + dialect: Box, registered_tables: RefCell>, } @@ -147,8 +87,8 @@ impl OdbcReader { } } - // Detect variant from connection string - let variant = detect_variant(&conn_str); + // Detect backend dialect from connection string + let dialect = detect_dialect(&conn_str); let env = odbc_env(); let connection = env @@ -157,7 +97,7 @@ impl OdbcReader { Ok(Self { connection, - dialect: OdbcDialect { variant }, + dialect, registered_tables: RefCell::new(HashSet::new()), }) } @@ -326,7 +266,7 @@ impl Reader for OdbcReader { } fn dialect(&self) -> &dyn super::SqlDialect { - &self.dialect + &*self.dialect } } @@ -457,23 +397,12 @@ fn has_token(conn_str: &str) -> bool { conn_str.to_lowercase().contains("token=") } -fn detect_variant(conn_str: &str) -> OdbcVariant { - let lower = conn_str.to_lowercase(); - if lower.contains("driver=snowflake") { - OdbcVariant::Snowflake - } else if lower.contains("driver={postgresql}") || lower.contains("driver=postgresql") { - OdbcVariant::PostgreSQL - } else if lower.contains("driver={odbc driver") || lower.contains("driver={sql server") { - OdbcVariant::SqlServer - } else { - OdbcVariant::Generic - } -} - fn home_dir() -> Option { #[cfg(target_os = "windows")] { - std::env::var("USERPROFILE").ok().map(std::path::PathBuf::from) + std::env::var("USERPROFILE") + .ok() + .map(std::path::PathBuf::from) } #[cfg(not(target_os = "windows"))] { @@ -505,8 +434,7 @@ fn 
find_snowflake_connections_toml() -> Option { if let Some(home) = home_dir() { #[cfg(target_os = "macos")] { - let p = home - .join("Library/Application Support/snowflake/connections.toml"); + let p = home.join("Library/Application Support/snowflake/connections.toml"); if p.exists() { return Some(p); } @@ -525,8 +453,7 @@ fn find_snowflake_connections_toml() -> Option { #[cfg(target_os = "windows")] { - let p = home - .join("AppData/Local/snowflake/connections.toml"); + let p = home.join("AppData/Local/snowflake/connections.toml"); if p.exists() { return Some(p); } @@ -565,9 +492,7 @@ fn resolve_connection_name(conn_str: &str) -> Option { } // Build ODBC connection string from TOML entry fields - let get_str = |key: &str| -> Option { - entry.get(key)?.as_str().map(|s| s.to_string()) - }; + let get_str = |key: &str| -> Option { entry.get(key)?.as_str().map(|s| s.to_string()) }; let account = get_str("account")?; let mut parts = vec![ @@ -656,19 +581,18 @@ mod tests { } #[test] - fn test_detect_variant() { - assert_eq!( - detect_variant("Driver=Snowflake;Server=foo"), - OdbcVariant::Snowflake - ); - assert_eq!( - detect_variant("Driver={PostgreSQL};Server=localhost"), - OdbcVariant::PostgreSQL - ); - assert_eq!( - detect_variant("Driver=SomeOther;Server=localhost"), - OdbcVariant::Generic - ); + fn test_detect_dialect() { + // Snowflake uses SHOW commands + let dialect = detect_dialect("Driver=Snowflake;Server=foo"); + assert!(dialect.sql_list_catalogs().contains("SHOW")); + + // PostgreSQL uses information_schema (ANSI default) + let dialect = detect_dialect("Driver={PostgreSQL};Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); + + // Generic uses information_schema (ANSI default) + let dialect = detect_dialect("Driver=SomeOther;Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); } #[test] @@ -712,8 +636,7 @@ account = "otheraccount" // Point SNOWFLAKE_HOME at our temp dir 
std::env::set_var("SNOWFLAKE_HOME", dir.path()); - let result = - resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); + let result = resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); assert!(result.is_some()); let conn = result.unwrap(); assert!(conn.contains("Driver=Snowflake")); @@ -726,21 +649,18 @@ account = "otheraccount" assert!(conn.contains("Role=myrole")); // Test with a connection that has fewer fields - let result2 = - resolve_connection_name("Driver=Snowflake;ConnectionName=other"); + let result2 = resolve_connection_name("Driver=Snowflake;ConnectionName=other"); assert!(result2.is_some()); let conn2 = result2.unwrap(); assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com")); assert!(!conn2.contains("UID=")); // Test with non-existent connection name - let result3 = - resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); + let result3 = resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); assert!(result3.is_none()); // No ConnectionName param → None - let result4 = - resolve_connection_name("Driver=Snowflake;Server=foo"); + let result4 = resolve_connection_name("Driver=Snowflake;Server=foo"); assert!(result4.is_none()); // Clean up env diff --git a/src/reader/snowflake.rs b/src/reader/snowflake.rs new file mode 100644 index 00000000..8752b8ad --- /dev/null +++ b/src/reader/snowflake.rs @@ -0,0 +1,32 @@ +//! Snowflake-specific SQL dialect. +//! +//! Overrides schema introspection to use Snowflake's SHOW commands +//! instead of information_schema queries. 
+ +pub struct SnowflakeDialect; + +impl super::SqlDialect for SnowflakeDialect { + fn sql_list_catalogs(&self) -> String { + "SHOW DATABASES".into() + } + + fn sql_list_schemas(&self, catalog: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + format!("SHOW SCHEMAS IN DATABASE \"{catalog_ident}\"") + } + + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); + format!("SHOW OBJECTS IN SCHEMA \"{catalog_ident}\".\"{schema_ident}\"") + } + + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); + let table_ident = table.replace('"', "\"\""); + format!( + "SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"" + ) + } +} From 35885bcab45685dde77f6e93d9e98139325bbb42 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 12:49:22 +0100 Subject: [PATCH 11/20] Data explorer integration --- ggsql-jupyter/src/data_explorer.rs | 695 +++++++++++++++++++++++++++++ ggsql-jupyter/src/kernel.rs | 123 ++++- ggsql-jupyter/src/lib.rs | 1 + ggsql-jupyter/src/main.rs | 1 + 4 files changed, 819 insertions(+), 1 deletion(-) create mode 100644 ggsql-jupyter/src/data_explorer.rs diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs new file mode 100644 index 00000000..56e5568e --- /dev/null +++ b/ggsql-jupyter/src/data_explorer.rs @@ -0,0 +1,695 @@ +//! Data explorer backend for the Positron data viewer. +//! +//! Implements the `positron.dataExplorer` comm protocol, providing SQL-backed +//! paginated data access. No full table load — each `get_data_values` request +//! issues a `SELECT ... LIMIT/OFFSET` query. + +use ggsql::reader::Reader; +use serde_json::{json, Value}; + +/// Result of handling an RPC call. 
+pub struct RpcResponse {
+    /// The JSON-RPC result to send as the reply.
+    pub result: Value,
+    /// An optional event to send on iopub (e.g. `return_column_profiles`).
+    pub event: Option<RpcEvent>,
+}
+
+/// An asynchronous event to send back on the comm after the RPC reply.
+pub struct RpcEvent {
+    pub method: String,
+    pub params: Value,
+}
+
+impl RpcResponse {
+    /// Create a simple reply with no async event.
+    pub fn reply(result: Value) -> Self {
+        Self {
+            result,
+            event: None,
+        }
+    }
+}
+
+/// Cached column metadata for a table.
+#[derive(Debug, Clone)]
+pub struct ColumnInfo {
+    pub name: String,
+    /// Backend-specific type name (e.g. "INTEGER", "VARCHAR").
+    pub type_name: String,
+    /// Positron display type (e.g. "integer", "string").
+    pub type_display: String,
+}
+
+/// State for one open data explorer comm.
+pub struct DataExplorerState {
+    /// Fully qualified and quoted table path, e.g. `"memory"."main"."users"`.
+    table_path: String,
+    /// Display title shown in the data viewer tab.
+    title: String,
+    /// Cached column schemas.
+    columns: Vec<ColumnInfo>,
+    /// Cached total row count.
+    num_rows: usize,
+}
+
+impl DataExplorerState {
+    /// Open a data explorer for a table at the given connection path.
+    ///
+    /// Runs `SELECT COUNT(*)` and a column metadata query to cache schema
+    /// information. Does **not** load the full table into memory. 
+    pub fn open(reader: &dyn Reader, path: &[String]) -> Result<Self, String> {
+        if path.len() < 3 {
+            return Err(format!(
+                "Expected [catalog, schema, table] path, got {} elements",
+                path.len()
+            ));
+        }
+
+        let catalog = &path[0];
+        let schema = &path[1];
+        let table = &path[2];
+
+        let table_path = format!(
+            "\"{}\".\"{}\".\"{}\"" ,
+            catalog.replace('"', "\"\""),
+            schema.replace('"', "\"\""),
+            table.replace('"', "\"\""),
+        );
+
+        // Get row count
+        let count_sql = format!("SELECT COUNT(*) AS n FROM {}", table_path);
+        let count_df = reader
+            .execute_sql(&count_sql)
+            .map_err(|e| format!("Failed to count rows: {}", e))?;
+        let num_rows = count_df
+            .column("n")
+            .ok()
+            .and_then(|col| col.get(0).ok())
+            .and_then(|val| {
+                // Polars AnyValue — try common integer representations
+                let s = format!("{}", val);
+                s.parse::<usize>().ok()
+            })
+            .unwrap_or(0);
+
+        // Get column metadata from information_schema
+        let columns_sql = reader.dialect().sql_list_columns(catalog, schema, table);
+        let columns_df = reader
+            .execute_sql(&columns_sql)
+            .map_err(|e| format!("Failed to list columns: {}", e))?;
+
+        let name_col = columns_df
+            .column("column_name")
+            .map_err(|e| format!("Missing column_name: {}", e))?;
+        let type_col = columns_df
+            .column("data_type")
+            .map_err(|e| format!("Missing data_type: {}", e))?;
+
+        let mut columns = Vec::new();
+        for i in 0..columns_df.height() {
+            if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
+                let name = name_val.to_string().trim_matches('"').to_string();
+                let type_name = type_val.to_string().trim_matches('"').to_string();
+                let type_display = sql_type_to_display(&type_name).to_string();
+                columns.push(ColumnInfo {
+                    name,
+                    type_name,
+                    type_display,
+                });
+            }
+        }
+
+        Ok(Self {
+            table_path,
+            title: table.clone(),
+            columns,
+            num_rows,
+        })
+    }
+
+    /// Dispatch a JSON-RPC method call. 
+ /// + /// Returns the RPC result and an optional async event to send on iopub + /// (used by `get_column_profiles` to deliver results asynchronously). + pub fn handle_rpc(&self, method: &str, params: &Value, reader: &dyn Reader) -> RpcResponse { + match method { + "get_state" => RpcResponse::reply(self.get_state()), + "get_schema" => RpcResponse::reply(self.get_schema(params)), + "get_data_values" => RpcResponse::reply(self.get_data_values(params, reader)), + "get_column_profiles" => self.get_column_profiles(params, reader), + "set_row_filters" => { + // Stub: accept but ignore filters, return current shape + RpcResponse::reply(json!({ + "selected_num_rows": self.num_rows, + "had_errors": false + })) + } + "set_sort_columns" | "set_column_filters" | "search_schema" => { + RpcResponse::reply(json!(null)) + } + _ => { + tracing::warn!("Unhandled data explorer method: {}", method); + RpcResponse::reply(json!(null)) + } + } + } + + fn get_state(&self) -> Value { + let num_columns = self.columns.len(); + json!({ + "display_name": self.title, + "table_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "table_unfiltered_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "has_row_labels": false, + "column_filters": [], + "row_filters": [], + "sort_keys": [], + "supported_features": { + "search_schema": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_column_filters": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_row_filters": { + "support_status": "unsupported", + "supports_conditions": "unsupported", + "supported_types": [] + }, + "get_column_profiles": { + "support_status": "supported", + "supported_types": [ + {"profile_type": "null_count", "support_status": "supported"}, + {"profile_type": "summary_stats", "support_status": "supported"} + ] + }, + "set_sort_columns": { + "support_status": "unsupported" + }, + "export_data_selection": { + "support_status": "unsupported", 
+                    "supported_formats": []
+                },
+                "convert_to_code": {
+                    "support_status": "unsupported"
+                }
+            }
+        })
+    }
+
+    fn get_schema(&self, params: &Value) -> Value {
+        let indices: Vec<usize> = params
+            .get("column_indices")
+            .and_then(|v| v.as_array())
+            .map(|arr| {
+                arr.iter()
+                    .filter_map(|v| v.as_u64().map(|n| n as usize))
+                    .collect()
+            })
+            .unwrap_or_default();
+
+        let columns: Vec<Value> = indices
+            .iter()
+            .filter_map(|&idx| {
+                self.columns.get(idx).map(|col| {
+                    json!({
+                        "column_name": col.name,
+                        "column_index": idx,
+                        "type_name": col.type_name,
+                        "type_display": col.type_display
+                    })
+                })
+            })
+            .collect();
+
+        json!({ "columns": columns })
+    }
+
+    fn get_data_values(&self, params: &Value, reader: &dyn Reader) -> Value {
+        let selections = match params.get("columns").and_then(|v| v.as_array()) {
+            Some(arr) => arr,
+            None => return json!({ "columns": [] }),
+        };
+
+        // Determine the row range from the first selection's spec
+        let (first_index, last_index) = selections
+            .first()
+            .and_then(|sel| sel.get("spec"))
+            .map(|spec| {
+                let first = spec
+                    .get("first_index")
+                    .and_then(|v| v.as_u64())
+                    .unwrap_or(0) as usize;
+                let last = spec
+                    .get("last_index")
+                    .and_then(|v| v.as_u64())
+                    .unwrap_or(0) as usize;
+                (first, last)
+            })
+            .unwrap_or((0, 0));
+
+        let limit = last_index.saturating_sub(first_index) + 1;
+
+        // Collect requested column indices
+        let col_indices: Vec<usize> = selections
+            .iter()
+            .filter_map(|sel| sel.get("column_index").and_then(|v| v.as_u64()).map(|n| n as usize))
+            .collect();
+
+        // Build column list for SELECT
+        let col_names: Vec<String> = col_indices
+            .iter()
+            .filter_map(|&idx| {
+                self.columns.get(idx).map(|col| {
+                    format!("\"{}\"", col.name.replace('"', "\"\""))
+                })
+            })
+            .collect();
+
+        if col_names.is_empty() {
+            return json!({ "columns": [] });
+        }
+
+        let sql = format!(
+            "SELECT {} FROM {} LIMIT {} OFFSET {}",
+            col_names.join(", "),
+            self.table_path,
+            limit,
+            first_index,
+        );
+
+        let df = match reader.execute_sql(&sql) {
 
Ok(df) => df,
+            Err(e) => {
+                tracing::error!("get_data_values query failed: {}", e);
+                let empty: Vec<Vec<Value>> = col_indices.iter().map(|_| vec![]).collect();
+                return json!({ "columns": empty });
+            }
+        };
+
+        // Format each column's values as strings.
+        // Positron's ColumnValue is `number | string`: numbers are special
+        // value codes (0 = NULL, 1 = NA, 2 = NaN), strings are formatted data.
+        const SPECIAL_VALUE_NULL: i64 = 0;
+
+        let columns: Vec<Vec<Value>> = (0..df.width())
+            .map(|col_idx| {
+                let col = df.get_columns()[col_idx].clone();
+                (0..df.height())
+                    .map(|row_idx| {
+                        match col.get(row_idx) {
+                            Ok(val) => {
+                                if val.is_null() {
+                                    json!(SPECIAL_VALUE_NULL)
+                                } else {
+                                    let s = format!("{}", val);
+                                    // Strip surrounding quotes from string values
+                                    let s = s.trim_matches('"');
+                                    Value::String(s.to_string())
+                                }
+                            }
+                            Err(_) => json!(SPECIAL_VALUE_NULL),
+                        }
+                    })
+                    .collect()
+            })
+            .collect();
+
+        json!({ "columns": columns })
+    }
+
+    /// Handle `get_column_profiles` — returns `{}` as the RPC result and sends
+    /// profile data back as an async `return_column_profiles` event. 
+ fn get_column_profiles(&self, params: &Value, reader: &dyn Reader) -> RpcResponse { + let callback_id = params + .get("callback_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let requests = match params.get("profiles").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => { + return RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": [] + }), + }), + }; + } + }; + + let mut profiles = Vec::new(); + for req in requests { + let col_idx = req + .get("column_index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + + let specs = req + .get("profiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let profile = self.compute_column_profile(col_idx, &specs, reader); + profiles.push(profile); + } + + RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": profiles + }), + }), + } + } + + /// Compute profile results for a single column. + fn compute_column_profile( + &self, + col_idx: usize, + specs: &[Value], + reader: &dyn Reader, + ) -> Value { + let col = match self.columns.get(col_idx) { + Some(c) => c, + None => return json!({}), + }; + + let mut wants_null_count = false; + let mut wants_summary = false; + for spec in specs { + match spec + .get("profile_type") + .and_then(|v| v.as_str()) + .unwrap_or("") + { + "null_count" => wants_null_count = true, + "summary_stats" => wants_summary = true, + _ => {} + } + } + + let dialect = reader.dialect(); + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + let display = col.type_display.as_str(); + + // Build a single SQL query that computes all needed aggregates. + // All expressions use ANSI SQL or existing dialect methods. 
+ let mut select_parts = Vec::new(); + if wants_null_count { + select_parts.push(format!( + "SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) AS null_count", + quoted_col + )); + } + if wants_summary { + match display { + "integer" | "floating" => { + select_parts.push(format!("MIN({}) AS min_val", quoted_col)); + select_parts.push(format!("MAX({}) AS max_val", quoted_col)); + select_parts.push(format!("AVG(CAST({} AS DOUBLE)) AS mean_val", quoted_col)); + // Stddev: fetch raw aggregates, compute in Rust + select_parts.push(format!( + "SUM(CAST({c} AS DOUBLE) * CAST({c} AS DOUBLE)) AS sum_sq", + c = quoted_col + )); + select_parts.push(format!( + "SUM(CAST({} AS DOUBLE)) AS sum_val", + quoted_col + )); + select_parts.push(format!("COUNT({}) AS cnt", quoted_col)); + } + "boolean" => { + let true_lit = dialect.sql_boolean_literal(true); + let false_lit = dialect.sql_boolean_literal(false); + select_parts.push(format!( + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS true_count", + quoted_col, true_lit + )); + select_parts.push(format!( + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS false_count", + quoted_col, false_lit + )); + } + "string" => { + select_parts.push(format!("COUNT(DISTINCT {}) AS num_unique", quoted_col)); + select_parts.push(format!( + "SUM(CASE WHEN {} = '' THEN 1 ELSE 0 END) AS num_empty", + quoted_col + )); + } + "date" | "datetime" => { + select_parts.push(format!("MIN({}) AS min_val", quoted_col)); + select_parts.push(format!("MAX({}) AS max_val", quoted_col)); + select_parts.push(format!("COUNT(DISTINCT {}) AS num_unique", quoted_col)); + } + _ => {} + } + } + + if select_parts.is_empty() { + return json!({}); + } + + let sql = format!( + "SELECT {} FROM {}", + select_parts.join(", "), + self.table_path + ); + + let df = match reader.execute_sql(&sql) { + Ok(df) => df, + Err(e) => { + tracing::error!("Column profile query failed: {}", e); + return json!({}); + } + }; + + let get_str = |name: &str| -> Option { + df.column(name) + .ok() + .and_then(|c| 
c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + }; + + let get_i64 = |name: &str| -> Option { + get_str(name).and_then(|s| s.parse::().ok()) + }; + + let get_f64 = |name: &str| -> Option { + get_str(name).and_then(|s| s.parse::().ok()) + }; + + let mut result = json!({}); + + if wants_null_count { + if let Some(n) = get_i64("null_count") { + result["null_count"] = json!(n); + } + } + + if wants_summary { + let stats = match display { + "integer" | "floating" => { + let mut number_stats = json!({}); + if let Some(v) = get_str("min_val") { + number_stats["min_value"] = json!(v); + } + if let Some(v) = get_str("max_val") { + number_stats["max_value"] = json!(v); + } + if let Some(v) = get_str("mean_val") { + number_stats["mean"] = json!(v); + } + // Compute sample stddev from raw aggregates + if let (Some(sum_sq), Some(sum_val), Some(cnt)) = + (get_f64("sum_sq"), get_f64("sum_val"), get_i64("cnt")) + { + if cnt > 1 { + let variance = + (sum_sq - sum_val * sum_val / cnt as f64) / (cnt - 1) as f64; + let stdev = variance.max(0.0).sqrt(); + number_stats["stdev"] = json!(format!("{}", stdev)); + } + } + // Median via dialect's sql_percentile (uses QUANTILE_CONT on + // DuckDB, NTILE fallback on other backends) + let col_name = col.name.replace('"', "\"\""); + let median_expr = + dialect.sql_percentile(&col_name, 0.5, &self.table_path, &[]); + let median_sql = format!("SELECT {} AS median_val", median_expr); + if let Ok(median_df) = reader.execute_sql(&median_sql) { + if let Some(v) = median_df + .column("median_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + { + number_stats["median"] = json!(v); + } + } + json!({ + "type_display": display, + "number_stats": number_stats + }) + } + "boolean" => { + json!({ + "type_display": display, + "boolean_stats": { + "true_count": 
get_i64("true_count").unwrap_or(0), + "false_count": get_i64("false_count").unwrap_or(0) + } + }) + } + "string" => { + json!({ + "type_display": display, + "string_stats": { + "num_unique": get_i64("num_unique").unwrap_or(0), + "num_empty": get_i64("num_empty").unwrap_or(0) + } + }) + } + "date" => { + let mut date_stats = json!({}); + if let Some(v) = get_str("min_val") { + date_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + date_stats["max_date"] = json!(v); + } + if let Some(n) = get_i64("num_unique") { + date_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "date_stats": date_stats + }) + } + "datetime" => { + let mut datetime_stats = json!({}); + if let Some(v) = get_str("min_val") { + datetime_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + datetime_stats["max_date"] = json!(v); + } + if let Some(n) = get_i64("num_unique") { + datetime_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "datetime_stats": datetime_stats + }) + } + _ => json!({"type_display": display}), + }; + result["summary_stats"] = stats; + } + + result + } +} + +/// Map a SQL type name (from information_schema) to a Positron display type. 
+fn sql_type_to_display(type_name: &str) -> &'static str { + let upper = type_name.to_uppercase(); + let upper = upper.as_str(); + + if upper.contains("INT") { + return "integer"; + } + if upper.contains("FLOAT") + || upper.contains("DOUBLE") + || upper.contains("REAL") + || upper.contains("NUMERIC") + || upper.contains("DECIMAL") + { + return "floating"; + } + if upper.contains("BOOL") { + return "boolean"; + } + if upper.contains("TIMESTAMP") || upper.contains("DATETIME") { + return "datetime"; + } + if upper.contains("DATE") { + return "date"; + } + if upper.contains("TIME") { + return "time"; + } + if upper.contains("CHAR") + || upper.contains("TEXT") + || upper.contains("STRING") + || upper.contains("VARCHAR") + || upper.contains("CLOB") + { + return "string"; + } + if upper.contains("BLOB") || upper.contains("BINARY") || upper.contains("BYTE") { + return "string"; + } + + "unknown" +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sql_type_to_display() { + assert_eq!(sql_type_to_display("INTEGER"), "integer"); + assert_eq!(sql_type_to_display("BIGINT"), "integer"); + assert_eq!(sql_type_to_display("SMALLINT"), "integer"); + assert_eq!(sql_type_to_display("TINYINT"), "integer"); + assert_eq!(sql_type_to_display("INT"), "integer"); + assert_eq!(sql_type_to_display("DOUBLE"), "floating"); + assert_eq!(sql_type_to_display("FLOAT"), "floating"); + assert_eq!(sql_type_to_display("REAL"), "floating"); + assert_eq!(sql_type_to_display("NUMERIC(10,2)"), "floating"); + assert_eq!(sql_type_to_display("DECIMAL(10,2)"), "floating"); + assert_eq!(sql_type_to_display("BOOLEAN"), "boolean"); + assert_eq!(sql_type_to_display("BOOL"), "boolean"); + assert_eq!(sql_type_to_display("VARCHAR"), "string"); + assert_eq!(sql_type_to_display("TEXT"), "string"); + assert_eq!(sql_type_to_display("DATE"), "date"); + assert_eq!(sql_type_to_display("TIMESTAMP"), "datetime"); + assert_eq!(sql_type_to_display("TIMESTAMP WITH TIME ZONE"), "datetime"); + 
assert_eq!(sql_type_to_display("TIME"), "time"); + assert_eq!(sql_type_to_display("BLOB"), "string"); + assert_eq!(sql_type_to_display("UNKNOWN_TYPE"), "unknown"); + } +} diff --git a/ggsql-jupyter/src/kernel.rs b/ggsql-jupyter/src/kernel.rs index aa72cf90..4e48ab1a 100644 --- a/ggsql-jupyter/src/kernel.rs +++ b/ggsql-jupyter/src/kernel.rs @@ -4,6 +4,7 @@ //! handling kernel_info, execute, and shutdown requests. use crate::connection; +use crate::data_explorer::{DataExplorerState, RpcResponse}; use crate::display::format_display_data; use crate::executor::{self, ExecutionResult, QueryExecutor}; use crate::message::{ConnectionInfo, JupyterMessage, MessageHeader}; @@ -11,6 +12,7 @@ use anyhow::Result; use hmac::{Hmac, Mac}; use serde_json::{json, Value}; use sha2::Sha256; +use std::collections::HashMap; use zeromq::{PubSocket, RepSocket, RouterSocket, Socket, SocketRecv, SocketSend}; type HmacSha256 = Hmac; @@ -34,6 +36,8 @@ pub struct KernelServer { ui_comm_id: Option, plot_comm_id: Option, connection_comm_id: Option, + // Open data explorer comms (comm_id → state) + data_explorer_comms: HashMap, } impl KernelServer { @@ -95,6 +99,7 @@ impl KernelServer { ui_comm_id: None, plot_comm_id: None, connection_comm_id: None, + data_explorer_comms: HashMap::new(), }; // Send initial "starting" status on IOPub @@ -612,6 +617,11 @@ impl KernelServer { self.handle_connection_rpc(method, rpc_id, comm_id, parent, identities) .await?; } + // Handle positron.dataExplorer requests + else if self.data_explorer_comms.contains_key(comm_id) { + self.handle_data_explorer_rpc(method, rpc_id, comm_id, parent, identities) + .await?; + } // Unknown comm else { tracing::warn!("Message for unknown comm_id: {}", comm_id); @@ -657,6 +667,11 @@ impl KernelServer { comms[id] = json!({"target_name": "positron.connection"}); } } + for id in self.data_explorer_comms.keys() { + if target_name.is_none() || target_name == Some("positron.dataExplorer") { + comms[id] = json!({"target_name": 
"positron.dataExplorer"}); + } + } tracing::info!( "Returning comms: {}", @@ -703,6 +718,8 @@ impl KernelServer { } else if Some(comm_id.to_string()) == self.connection_comm_id { tracing::info!("Closing positron.connection comm"); self.connection_comm_id = None; + } else if self.data_explorer_comms.remove(comm_id).is_some() { + tracing::info!("Closing data explorer comm: {}", comm_id); } else { tracing::warn!("Close for unknown comm_id: {}", comm_id); } @@ -820,7 +837,55 @@ impl KernelServer { json!(has_data) } "get_icon" => json!(""), - "preview_object" => json!(null), + "preview_object" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + + match DataExplorerState::open(self.executor.reader(), &path) { + Ok(state) => { + let de_comm_id = uuid::Uuid::new_v4().to_string(); + let title = path.last().cloned().unwrap_or_default(); + + // Send comm_open on iopub to open the data viewer + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": de_comm_id, + "target_name": "positron.dataExplorer", + "data": { + "title": title + } + }), + Some(parent), + ); + let zmq_msg = + self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + tracing::info!( + "Opened data explorer comm: {} for {}", + de_comm_id, + title + ); + self.data_explorer_comms + .insert(de_comm_id, state); + } + Err(e) => { + tracing::error!("preview_object failed: {}", e); + } + } + json!(null) + } "get_metadata" => { let uri = self.executor.reader_uri(); json!({ @@ -855,6 +920,62 @@ impl KernelServer { Ok(()) } + /// Handle JSON-RPC requests on a data explorer comm + async fn handle_data_explorer_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Data explorer RPC: {}", method); + + let 
params = &parent.content["data"]["params"]; + + let RpcResponse { result, event } = + if let Some(state) = self.data_explorer_comms.get(comm_id) { + state.handle_rpc(method, params, self.executor.reader()) + } else { + RpcResponse::reply(json!(null)) + }; + + // Send the RPC reply + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + // Send async event on iopub if present (e.g. return_column_profiles) + if let Some(evt) = event { + self.send_iopub( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "method": evt.method, + "params": evt.params + } + }), + parent, + ) + .await?; + } + + Ok(()) + } + /// Send a message on the IOPub channel async fn send_iopub( &mut self, diff --git a/ggsql-jupyter/src/lib.rs b/ggsql-jupyter/src/lib.rs index 426f767d..5d2aaf55 100644 --- a/ggsql-jupyter/src/lib.rs +++ b/ggsql-jupyter/src/lib.rs @@ -3,6 +3,7 @@ //! This module exposes the internal components for testing. pub mod connection; +pub mod data_explorer; pub mod display; pub mod executor; pub mod message; diff --git a/ggsql-jupyter/src/main.rs b/ggsql-jupyter/src/main.rs index 9417e791..896ce756 100644 --- a/ggsql-jupyter/src/main.rs +++ b/ggsql-jupyter/src/main.rs @@ -3,6 +3,7 @@ //! A Jupyter kernel for executing ggsql queries with rich Vega-Lite visualizations. 
mod connection; +mod data_explorer; mod display; mod executor; mod kernel; From 75ad8439710656728f33b6e180f8f85aa487ebda Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 12:52:05 +0100 Subject: [PATCH 12/20] Reply to JSON-RPC even if null --- ggsql-jupyter/src/kernel.rs | 43 +++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/ggsql-jupyter/src/kernel.rs b/ggsql-jupyter/src/kernel.rs index 4e48ab1a..2229048e 100644 --- a/ggsql-jupyter/src/kernel.rs +++ b/ggsql-jupyter/src/kernel.rs @@ -516,6 +516,7 @@ impl KernelServer { self.send_status("busy", parent).await?; // Check if it's a JSON-RPC request + #[allow(clippy::if_same_then_else)] if let Some(method) = data["method"].as_str() { let rpc_id = &data["id"]; @@ -606,11 +607,37 @@ impl KernelServer { } // Handle positron.ui requests else if Some(comm_id.to_string()) == self.ui_comm_id { - tracing::info!("Received UI request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; } // Handle positron.plot requests else if Some(comm_id.to_string()) == self.plot_comm_id { - tracing::info!("Received plot request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; } // Handle positron.connection requests else if Some(comm_id.to_string()) == self.connection_comm_id { @@ -868,17 +895,11 @@ impl KernelServer { }), Some(parent), ); - let zmq_msg = - self.serialize_message_with_topic(&msg, "comm_open")?; + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; self.iopub.send(zmq_msg).await?; - tracing::info!( - "Opened data explorer comm: {} for {}", - de_comm_id, - title - ); - self.data_explorer_comms - .insert(de_comm_id, state); + 
tracing::info!("Opened data explorer comm: {} for {}", de_comm_id, title); + self.data_explorer_comms.insert(de_comm_id, state); } Err(e) => { tracing::error!("preview_object failed: {}", e); From 5b9fb7956135e4940d45cad45f855de20beced7c Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 13:09:26 +0100 Subject: [PATCH 13/20] More quoting fun --- src/execute/layer.rs | 2 +- src/plot/layer/geom/rect.rs | 38 +++++++++++++++++------------------ src/plot/layer/geom/smooth.rs | 16 +++++++-------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 1ad3eca6..2afd270a 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -54,7 +54,7 @@ pub fn layer_source_query( None => { // Layer uses global data debug_assert!(has_global, "Layer has no source and no global data"); - Ok(format!("SELECT * FROM {}", naming::global_table())) + Ok(format!("SELECT * FROM \"{}\"", naming::global_table())) } } } diff --git a/src/plot/layer/geom/rect.rs b/src/plot/layer/geom/rect.rs index fdd2584a..840276b2 100644 --- a/src/plot/layer/geom/rect.rs +++ b/src/plot/layer/geom/rect.rs @@ -172,8 +172,8 @@ fn process_direction( // Build SELECT parts using the stat columns let select_parts = vec![ - format!("{} AS {}", expr_1, naming::stat_column(&stat_cols[0])), - format!("{} AS {}", expr_2, naming::stat_column(&stat_cols[1])), + format!("{} AS \"{}\"", expr_1, naming::stat_column(&stat_cols[0])), + format!("{} AS \"{}\"", expr_2, naming::stat_column(&stat_cols[1])), ]; Ok((select_parts, stat_cols)) @@ -522,7 +522,7 @@ mod tests { let stat_pos1min = naming::stat_column("pos1min"); let stat_pos1max = naming::stat_column("pos1max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos1min)), + query.contains(&format!("{} AS \"{}\"", expected_min, stat_pos1min)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_min, @@ -530,7 +530,7 @@ mod tests { query ); assert!( - 
query.contains(&format!("{} AS {}", expected_max, stat_pos1max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos1max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -632,7 +632,7 @@ mod tests { let stat_pos2min = naming::stat_column("pos2min"); let stat_pos2max = naming::stat_column("pos2max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos2min)), + query.contains(&format!("{} AS \"{}\"", expected_min, stat_pos2min)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_min, @@ -640,7 +640,7 @@ mod tests { query ); assert!( - query.contains(&format!("{} AS {}", expected_max, stat_pos2max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos2max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -687,8 +687,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); + assert!(query.contains("__ggsql_aes_pos1__ AS \"__ggsql_stat_pos1")); + assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"pos1".to_string())); assert!(stat_columns.contains(&"width".to_string())); assert!(stat_columns.contains(&"pos2min".to_string())); @@ -718,8 +718,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("__ggsql_aes_pos2__ AS \"__ggsql_stat_pos2")); + assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); assert!(stat_columns.contains(&"pos2".to_string())); @@ -749,10 +749,10 @@ mod tests { .. 
}) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("__ggsql_aes_pos1__ AS \"__ggsql_stat_pos1")); + assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); + assert!(query.contains("__ggsql_aes_pos2__ AS \"__ggsql_stat_pos2")); + assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); assert_eq!(stat_columns.len(), 4); } } @@ -852,7 +852,7 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("1.0 AS __ggsql_stat_width")); + assert!(query.contains("1.0 AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"width".to_string())); } _ => panic!("Expected Transformed"), @@ -883,8 +883,8 @@ mod tests { assert!(query.contains("__ggsql_aes_fill__")); // Should NOT include width/height as pass-through (they're consumed) // They should only appear as stat columns - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); + assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); } } @@ -909,8 +909,8 @@ mod tests { if let Ok(StatResult::Transformed { query, .. 
}) = result { // Should use SETTING values as SQL literals - assert!(query.contains("0.7 AS __ggsql_stat_width")); - assert!(query.contains("0.9 AS __ggsql_stat_height")); + assert!(query.contains("0.7 AS \"__ggsql_stat_width")); + assert!(query.contains("0.9 AS \"__ggsql_stat_height")); } } } diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index fad14432..806e2696 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -172,13 +172,13 @@ fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Result Date: Tue, 31 Mar 2026 13:18:16 +0100 Subject: [PATCH 14/20] Fix casting in median calculation --- ggsql-jupyter/src/data_explorer.rs | 18 ++++++++++++------ src/plot/scale/scale_type/mod.rs | 2 +- src/reader/mod.rs | 4 ++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs index 56e5568e..2578c37d 100644 --- a/ggsql-jupyter/src/data_explorer.rs +++ b/ggsql-jupyter/src/data_explorer.rs @@ -423,17 +423,22 @@ impl DataExplorerState { if wants_summary { match display { "integer" | "floating" => { + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); select_parts.push(format!("MIN({}) AS min_val", quoted_col)); select_parts.push(format!("MAX({}) AS max_val", quoted_col)); - select_parts.push(format!("AVG(CAST({} AS DOUBLE)) AS mean_val", quoted_col)); + select_parts.push(format!( + "AVG(CAST({} AS {})) AS mean_val", + quoted_col, float_type + )); // Stddev: fetch raw aggregates, compute in Rust select_parts.push(format!( - "SUM(CAST({c} AS DOUBLE) * CAST({c} AS DOUBLE)) AS sum_sq", - c = quoted_col + "SUM(CAST({c} AS {t}) * CAST({c} AS {t})) AS sum_sq", + c = quoted_col, + t = float_type )); select_parts.push(format!( - "SUM(CAST({} AS DOUBLE)) AS sum_val", - quoted_col + "SUM(CAST({} AS {})) AS sum_val", + quoted_col, float_type )); select_parts.push(format!("COUNT({}) AS cnt", quoted_col)); 
} @@ -539,8 +544,9 @@ impl DataExplorerState { // Median via dialect's sql_percentile (uses QUANTILE_CONT on // DuckDB, NTILE fallback on other backends) let col_name = col.name.replace('"', "\"\""); + let from_query = format!("SELECT * FROM {}", self.table_path); let median_expr = - dialect.sql_percentile(&col_name, 0.5, &self.table_path, &[]); + dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); let median_sql = format!("SELECT {} AS median_val", median_expr); if let Ok(median_df) = reader.execute_sql(&median_sql) { if let Some(v) = median_df diff --git a/src/plot/scale/scale_type/mod.rs b/src/plot/scale/scale_type/mod.rs index baac8f4b..d54dfdae 100644 --- a/src/plot/scale/scale_type/mod.rs +++ b/src/plot/scale/scale_type/mod.rs @@ -3394,7 +3394,7 @@ mod tests { let dialect = AnsiDialect; assert_eq!( dialect.type_name_for(CastTargetType::Number), - Some("DOUBLE") + Some("DOUBLE PRECISION") ); assert_eq!( dialect.type_name_for(CastTargetType::Integer), diff --git a/src/reader/mod.rs b/src/reader/mod.rs index e71bd0b0..6c04e4d5 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -46,9 +46,9 @@ use crate::{DataFrame, GgsqlError, Result}; /// /// Default implementations produce portable ANSI SQL. 
pub trait SqlDialect { - /// SQL type name for numeric columns (e.g., "DOUBLE") + /// SQL type name for numeric columns (e.g., "DOUBLE PRECISION") fn number_type_name(&self) -> Option<&str> { - Some("DOUBLE") + Some("DOUBLE PRECISION") } /// SQL type name for integer columns (e.g., "BIGINT") From 70b5bc7e3e6099681a6afa4fe665f98d32ffce69 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 13:47:18 +0100 Subject: [PATCH 15/20] Support sparklines in data explorer --- ggsql-jupyter/src/data_explorer.rs | 292 +++++++++++++++++++++++++++-- 1 file changed, 276 insertions(+), 16 deletions(-) diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs index 2578c37d..e8079000 100644 --- a/ggsql-jupyter/src/data_explorer.rs +++ b/ggsql-jupyter/src/data_explorer.rs @@ -189,7 +189,9 @@ impl DataExplorerState { "support_status": "supported", "supported_types": [ {"profile_type": "null_count", "support_status": "supported"}, - {"profile_type": "summary_stats", "support_status": "supported"} + {"profile_type": "summary_stats", "support_status": "supported"}, + {"profile_type": "small_histogram", "support_status": "supported"}, + {"profile_type": "small_frequency_table", "support_status": "supported"} ] }, "set_sort_columns": { @@ -395,6 +397,8 @@ impl DataExplorerState { let mut wants_null_count = false; let mut wants_summary = false; + let mut histogram_params: Option<&Value> = None; + let mut freq_table_params: Option<&Value> = None; for spec in specs { match spec .get("profile_type") @@ -403,6 +407,8 @@ impl DataExplorerState { { "null_count" => wants_null_count = true, "summary_stats" => wants_summary = true, + "small_histogram" => histogram_params = spec.get("params"), + "small_frequency_table" => freq_table_params = spec.get("params"), _ => {} } } @@ -416,7 +422,7 @@ impl DataExplorerState { let mut select_parts = Vec::new(); if wants_null_count { select_parts.push(format!( - "SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) AS 
null_count", + "SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) AS \"null_count\"", quoted_col )); } @@ -424,47 +430,47 @@ impl DataExplorerState { match display { "integer" | "floating" => { let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); - select_parts.push(format!("MIN({}) AS min_val", quoted_col)); - select_parts.push(format!("MAX({}) AS max_val", quoted_col)); + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); select_parts.push(format!( - "AVG(CAST({} AS {})) AS mean_val", + "AVG(CAST({} AS {})) AS \"mean_val\"", quoted_col, float_type )); // Stddev: fetch raw aggregates, compute in Rust select_parts.push(format!( - "SUM(CAST({c} AS {t}) * CAST({c} AS {t})) AS sum_sq", + "SUM(CAST({c} AS {t}) * CAST({c} AS {t})) AS \"sum_sq\"", c = quoted_col, t = float_type )); select_parts.push(format!( - "SUM(CAST({} AS {})) AS sum_val", + "SUM(CAST({} AS {})) AS \"sum_val\"", quoted_col, float_type )); - select_parts.push(format!("COUNT({}) AS cnt", quoted_col)); + select_parts.push(format!("COUNT({}) AS \"cnt\"", quoted_col)); } "boolean" => { let true_lit = dialect.sql_boolean_literal(true); let false_lit = dialect.sql_boolean_literal(false); select_parts.push(format!( - "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS true_count", + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"true_count\"", quoted_col, true_lit )); select_parts.push(format!( - "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS false_count", + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"false_count\"", quoted_col, false_lit )); } "string" => { - select_parts.push(format!("COUNT(DISTINCT {}) AS num_unique", quoted_col)); + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); select_parts.push(format!( - "SUM(CASE WHEN {} = '' THEN 1 ELSE 0 END) AS num_empty", + "SUM(CASE WHEN {} = '' THEN 1 ELSE 0 END) AS \"num_empty\"", quoted_col )); } "date" | "datetime" => { - 
select_parts.push(format!("MIN({}) AS min_val", quoted_col)); - select_parts.push(format!("MAX({}) AS max_val", quoted_col)); - select_parts.push(format!("COUNT(DISTINCT {}) AS num_unique", quoted_col)); + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); } _ => {} } @@ -547,7 +553,7 @@ impl DataExplorerState { let from_query = format!("SELECT * FROM {}", self.table_path); let median_expr = dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); - let median_sql = format!("SELECT {} AS median_val", median_expr); + let median_sql = format!("SELECT {} AS \"median_val\"", median_expr); if let Ok(median_df) = reader.execute_sql(&median_sql) { if let Some(v) = median_df .column("median_val") @@ -624,8 +630,262 @@ impl DataExplorerState { result["summary_stats"] = stats; } + // Compute histogram if requested (only for numeric types) + if let Some(params) = histogram_params { + if matches!(display, "integer" | "floating") { + if let Some(hist) = self.compute_histogram(col, params, reader) { + result["small_histogram"] = hist; + } + } + } + + // Compute frequency table if requested (for string/boolean types) + if let Some(params) = freq_table_params { + if matches!(display, "string" | "boolean") { + if let Some(ft) = self.compute_frequency_table(col, params, reader) { + result["small_frequency_table"] = ft; + } + } + } + result } + + /// Compute a histogram for a numeric column. 
+ fn compute_histogram( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let max_bins = params + .get("num_bins") + .and_then(|v| v.as_u64()) + .unwrap_or(20) as usize; + + if max_bins == 0 { + return None; + } + + let dialect = reader.dialect(); + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + let is_integer = col.type_display == "integer"; + + // Get min, max, count in one query + let bounds_sql = format!( + "SELECT \ + MIN(CAST({c} AS {t})) AS \"min_val\", \ + MAX(CAST({c} AS {t})) AS \"max_val\", \ + COUNT({c}) AS \"cnt\" \ + FROM {table} WHERE {c} IS NOT NULL", + c = quoted_col, + t = float_type, + table = self.table_path, + ); + + let bounds_df = reader.execute_sql(&bounds_sql).ok()?; + let get_f64 = |name: &str| -> Option { + bounds_df + .column(name) + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + format!("{}", v).trim_matches('"').parse::().ok() + } + }) + }; + + let min_val = get_f64("min_val")?; + let max_val = get_f64("max_val")?; + let count = get_f64("cnt").unwrap_or(0.0) as usize; + + // Handle edge case: all values identical + if (max_val - min_val).abs() < f64::EPSILON { + return Some(json!({ + "bin_edges": [format!("{}", min_val), format!("{}", max_val)], + "bin_counts": [count as i64], + "quantiles": [] + })); + } + + // Determine actual bin count using Sturges' formula, capped at max_bins. + // For integers, also cap at (max - min + 1) to avoid sub-unit bins. + let mut num_bins = if count > 1 { + ((count as f64).log2().ceil() as usize + 1).max(1) + } else { + 1 + }; + if is_integer { + let int_range = (max_val - min_val) as usize + 1; + num_bins = num_bins.min(int_range); + } + num_bins = num_bins.min(max_bins).max(1); + + let bin_width = (max_val - min_val) / num_bins as f64; + + // Bin the data using FLOOR. 
Clamp the last bin to num_bins-1 so + // max value doesn't create an extra bin. + let hist_sql = format!( + "SELECT \ + CASE \ + WHEN \"bin\" >= {num_bins} THEN {last_bin} \ + ELSE \"bin\" \ + END AS \"clamped_bin\", \ + COUNT(*) AS \"cnt\" \ + FROM ( \ + SELECT FLOOR((CAST({c} AS {t}) - {min}) / {width}) AS \"bin\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + ) AS \"__bins__\" \ + GROUP BY \"clamped_bin\" \ + ORDER BY \"clamped_bin\"", + c = quoted_col, + t = float_type, + table = self.table_path, + min = min_val, + width = bin_width, + num_bins = num_bins, + last_bin = num_bins - 1, + ); + + let hist_df = reader.execute_sql(&hist_sql).ok()?; + + // Build bin_edges: num_bins + 1 edges + let bin_edges: Vec = (0..=num_bins) + .map(|i| format!("{}", min_val + i as f64 * bin_width)) + .collect(); + + // Build bin_counts: fill from query results (sparse bins get 0) + let mut bin_counts = vec![0i64; num_bins]; + let bin_col = hist_df.column("clamped_bin").ok()?; + let cnt_col = hist_df.column("cnt").ok()?; + for i in 0..hist_df.height() { + if let (Ok(bin_val), Ok(cnt_val)) = (bin_col.get(i), cnt_col.get(i)) { + let bin_str = format!("{}", bin_val); + // Parse bin index — may be float (e.g., "3.0") on some backends + if let Ok(bin_idx) = bin_str.parse::() { + let idx = bin_idx as usize; + if idx < num_bins { + let count_str = format!("{}", cnt_val); + bin_counts[idx] = count_str.parse::().unwrap_or(0); + } + } + } + } + + // Compute requested quantiles + let quantiles_param = params + .get("quantiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let mut quantile_results = Vec::new(); + let from_query = format!("SELECT * FROM {}", self.table_path); + let col_name = col.name.replace('"', "\"\""); + for q in &quantiles_param { + if let Some(q_val) = q.as_f64() { + let expr = dialect.sql_percentile(&col_name, q_val, &from_query, &[]); + let q_sql = format!("SELECT {} AS \"q_val\"", expr); + if let Ok(q_df) = reader.execute_sql(&q_sql) { + if let 
Some(v) = q_df + .column("q_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + { + quantile_results.push(json!({"q": q_val, "value": v})); + } + } + } + } + + Some(json!({ + "bin_edges": bin_edges, + "bin_counts": bin_counts, + "quantiles": quantile_results + })) + } + + /// Compute a frequency table for a string or boolean column. + fn compute_frequency_table( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let limit = params + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(8) as usize; + + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + + let sql = format!( + "SELECT {c} AS \"value\", COUNT(*) AS \"count\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + GROUP BY {c} \ + ORDER BY COUNT(*) DESC \ + LIMIT {limit}", + c = quoted_col, + table = self.table_path, + limit = limit, + ); + + let df = reader.execute_sql(&sql).ok()?; + + let val_col = df.column("value").ok()?; + let cnt_col = df.column("count").ok()?; + + let mut values = Vec::new(); + let mut counts = Vec::new(); + let mut top_total: i64 = 0; + + for i in 0..df.height() { + if let (Ok(v), Ok(c)) = (val_col.get(i), cnt_col.get(i)) { + let val_str = format!("{}", v).trim_matches('"').to_string(); + let count: i64 = format!("{}", c).parse().unwrap_or(0); + values.push(Value::String(val_str)); + counts.push(count); + top_total += count; + } + } + + // Compute other_count: total non-null rows minus the top-K sum + let count_sql = format!( + "SELECT COUNT({c}) AS \"total\" FROM {table}", + c = quoted_col, + table = self.table_path, + ); + let other_count = reader + .execute_sql(&count_sql) + .ok() + .and_then(|df| { + df.column("total") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| format!("{}", v).parse::().ok()) + }) + .map(|total| total - top_total) + .unwrap_or(0); + + Some(json!({ + "values": values, + "counts": counts, 
+ "other_count": other_count + })) + } } /// Map a SQL type name (from information_schema) to a Positron display type. From b0f25a8e9c330a3eabcf0caa88af11846e9bf7d4 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 14:19:08 +0100 Subject: [PATCH 16/20] Snowflake issues --- ggsql-jupyter/src/data_explorer.rs | 67 ++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs index e8079000..a5794141 100644 --- a/ggsql-jupyter/src/data_explorer.rs +++ b/ggsql-jupyter/src/data_explorer.rs @@ -78,7 +78,7 @@ impl DataExplorerState { ); // Get row count - let count_sql = format!("SELECT COUNT(*) AS n FROM {}", table_path); + let count_sql = format!("SELECT COUNT(*) AS \"n\" FROM {}", table_path); let count_df = reader .execute_sql(&count_sql) .map_err(|e| format!("Failed to count rows: {}", e))?; @@ -110,8 +110,9 @@ impl DataExplorerState { for i in 0..columns_df.height() { if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { let name = name_val.to_string().trim_matches('"').to_string(); - let type_name = type_val.to_string().trim_matches('"').to_string(); - let type_display = sql_type_to_display(&type_name).to_string(); + let raw_type = type_val.to_string().trim_matches('"').to_string(); + let type_display = sql_type_to_display(&raw_type).to_string(); + let type_name = clean_type_name(&raw_type); columns.push(ColumnInfo { name, type_name, @@ -888,8 +889,39 @@ impl DataExplorerState { } } -/// Map a SQL type name (from information_schema) to a Positron display type. +/// Map a SQL type name (from information_schema or SHOW COLUMNS) to a Positron display type. +/// +/// Handles both simple type names (e.g. "INTEGER", "VARCHAR") and Snowflake's +/// JSON format (e.g. `{"type":"FIXED","precision":38,"scale":0,...}`). 
fn sql_type_to_display(type_name: &str) -> &'static str { + // Handle Snowflake JSON type format + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + "floating" + } else { + "integer" + } + } + "REAL" | "FLOAT" => "floating", + "TEXT" => "string", + "BOOLEAN" => "boolean", + "DATE" => "date", + "TIMESTAMP_NTZ" | "TIMESTAMP_LTZ" | "TIMESTAMP_TZ" => "datetime", + "TIME" => "time", + "BINARY" => "string", + "VARIANT" | "OBJECT" | "ARRAY" => "string", + _ => "unknown", + }; + } + } + } + + // Simple type names (DuckDB, PostgreSQL, SQLite, etc.) let upper = type_name.to_uppercase(); let upper = upper.as_str(); @@ -931,6 +963,33 @@ fn sql_type_to_display(type_name: &str) -> &'static str { "unknown" } +/// Clean up a raw type name for display in the schema response. +/// +/// For Snowflake JSON types, extracts the `type` field (e.g. "NUMBER", "TEXT"). +/// For simple type names, returns as-is. 
+fn clean_type_name(type_name: &str) -> String { + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + format!("NUMBER({},{})", + obj.get("precision").and_then(|v| v.as_i64()).unwrap_or(38), + scale) + } else { + "NUMBER".to_string() + } + } + other => other.to_string(), + }; + } + } + } + type_name.to_string() +} + #[cfg(test)] mod tests { use super::*; From f76a7d87cc9c5c6a9590a294b2a8f85984e91f07 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 14:55:46 +0100 Subject: [PATCH 17/20] More fun with quoting --- src/execute/layer.rs | 4 +- src/naming.rs | 15 ++++ src/plot/layer/geom/bar.rs | 4 +- src/plot/layer/geom/boxplot.rs | 10 +-- src/plot/layer/geom/density.rs | 36 ++++----- src/plot/layer/geom/histogram.rs | 6 +- src/plot/layer/geom/rect.rs | 100 ++++++++++++------------ src/plot/layer/geom/smooth.rs | 10 +-- src/plot/layer/geom/types.rs | 5 ++ src/plot/scale/scale_type/binned.rs | 3 +- src/plot/scale/scale_type/continuous.rs | 5 +- src/reader/duckdb.rs | 4 +- src/reader/mod.rs | 4 +- 13 files changed, 114 insertions(+), 92 deletions(-) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 2afd270a..6646f80f 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -314,9 +314,9 @@ pub fn apply_pre_stat_transform( .filter(|col| seen.insert(&col.name)) .map(|col| { if let Some((_, sql)) = transform_exprs.iter().find(|(c, _)| c == &col.name) { - format!("{} AS \"{}\"", sql, col.name) + format!("{} AS {}", sql, naming::quote_ident(&col.name)) } else { - format!("\"{}\"", col.name) + naming::quote_ident(&col.name) } }) .collect(); diff --git a/src/naming.rs b/src/naming.rs index 882f40dc..b36a05ec 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -224,6 +224,21 @@ pub fn aesthetic_column(aesthetic: &str) -> String { 
format!("{}{}{}", AES_PREFIX, aesthetic, GGSQL_SUFFIX) } +// ============================================================================ +// SQL Quoting +// ============================================================================ + +/// Double-quote a SQL identifier for case-preserving databases (e.g. Snowflake). +/// +/// # Example +/// ``` +/// use ggsql::naming; +/// assert_eq!(naming::quote_ident("__ggsql_aes_x__"), "\"__ggsql_aes_x__\""); +/// ``` +pub fn quote_ident(name: &str) -> String { + format!("\"{}\"", name) +} + // ============================================================================ // Detection Functions // ============================================================================ diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index a3c301ed..f62aa5bb 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -178,7 +178,7 @@ fn stat_bar_count( if let Some(weight_col) = weight_value.column_name() { if schema_columns.contains(weight_col) { // weight column exists - use SUM (but still call it "count") - format!("SUM({}) AS \"{}\"", weight_col, stat_count) + format!("SUM({}) AS \"{}\"", naming::quote_ident(weight_col), stat_count) } else { // weight mapped but column doesn't exist - fall back to COUNT // (this shouldn't happen with upfront validation, but handle gracefully) @@ -264,7 +264,7 @@ fn stat_bar_count( ) } else { // x is mapped - use existing logic with two-stage query - let x_col = x_col.unwrap(); + let x_col = naming::quote_ident(&x_col.unwrap()); // Build grouped columns (group_by includes partition_by + facet variables + x) let group_cols = if group_by.is_empty() { diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index e7f6f507..54633d7c 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -169,7 +169,7 @@ fn boxplot_sql_compute_summary( coef: &f64, dialect: &dyn SqlDialect, ) -> String { - let quoted_groups: Vec = 
groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); let groups_str = quoted_groups.join(", "); let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]); let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]); @@ -178,7 +178,7 @@ fn boxplot_sql_compute_summary( let q3 = dialect.sql_percentile(value, 0.75, from, groups); let qt = "\"__ggsql_qt__\""; let fn_alias = "\"__ggsql_fn__\""; - let quoted_value = format!("\"{}\"", value); + let quoted_value = naming::quote_ident(value); format!( "SELECT *, @@ -211,12 +211,12 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St let mut join_pairs = Vec::new(); let mut keep_columns = Vec::new(); for column in groups { - let quoted = format!("\"{}\"", column); + let quoted = naming::quote_ident(column); join_pairs.push(format!("raw.{} = summary.{}", quoted, quoted)); keep_columns.push(format!("raw.{}", quoted)); } - let quoted_value = format!("\"{}\"", value); + let quoted_value = naming::quote_ident(value); // We're joining outliers with the summary to use the lower/upper whisker // values as a filter format!( @@ -245,7 +245,7 @@ fn boxplot_sql_append_outliers( let value2_name = naming::stat_column("value2"); let type_name = naming::stat_column("type"); - let quoted_groups: Vec = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); let groups_str = quoted_groups.join(", "); // Helper to build visual-element rows from summary table diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index 90eb072c..cfeaf556 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -229,7 +229,7 @@ fn density_sql_bandwidth( let (groups_select, group_by) = if groups.is_empty() { (String::new(), String::new()) } else { - let quoted_groups: Vec 
= groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); let groups_str = quoted_groups.join(", "); ( format!("\n {},", groups_str), @@ -237,7 +237,7 @@ fn density_sql_bandwidth( ) }; - let quoted_value = format!("\"{}\"", value); + let quoted_value = naming::quote_ident(value); format!( "WITH RECURSIVE bandwidth AS ( @@ -266,7 +266,7 @@ fn silverman_rule( // The query computes Silverman's rule of thumb (R's `stats::bw.nrd0()`). // We absorb the adjustment in the 0.9 multiplier of the rule let adjust = 0.9 * adjust; - let v = format!("\"{}\"", value_column); + let v = naming::quote_ident(value_column); let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = v); let q75 = dialect.sql_percentile(value_column, 0.75, from, groups); let q25 = dialect.sql_percentile(value_column, 0.25, from, groups); @@ -354,28 +354,28 @@ fn build_data_cte( ) -> String { // Include weight column if provided, otherwise default to 1.0 let weight_col = if let Some(w) = weight { - format!(", \"{}\" AS weight", w) + format!(", {} AS weight", naming::quote_ident(w)) } else { ", 1.0 AS weight".to_string() }; let smooth_col = if let Some(s) = smooth { - format!(", \"{}\"", s) + format!(", {}", naming::quote_ident(s)) } else { "".to_string() }; - let quoted_value = format!("\"{}\"", value); + let quoted_value = naming::quote_ident(value); // Only filter out nulls in value column, keep NULLs in group columns let mut filter_valid = format!("{} IS NOT NULL", quoted_value); if let Some(s) = smooth { filter_valid = format!( - "{filter} AND \"{}\" IS NOT NULL", - s, + "{filter} AND {} IS NOT NULL", + naming::quote_ident(s), filter = filter_valid, ); } - let quoted_groups: Vec = group_by.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted_groups: Vec = group_by.iter().map(|g| naming::quote_ident(g)).collect(); format!( "data AS ( SELECT {groups}{value} AS val{weight_col}{smooth_col} @@ -430,7 +430,7 
@@ fn build_grid_cte( x_formula = x_formula ) } else { - let quoted_groups: Vec = groups.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); let groups_str = quoted_groups.join(", "); // When tails is specified, create full_grid; otherwise create grid directly let cte_name = if tails.is_some() { "full_grid" } else { "grid" }; @@ -455,14 +455,12 @@ fn build_grid_cte( let bandwidth_join_conds: Vec = groups .iter() .map(|g| { - format!( - "full_grid.\"{}\" IS NOT DISTINCT FROM bandwidth.\"{}\"", - g, g - ) + let q = naming::quote_ident(g); + format!("full_grid.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) .collect(); let grid_groups_select: Vec = - groups.iter().map(|g| format!("full_grid.\"{}\"", g)).collect(); + groups.iter().map(|g| format!("full_grid.{}", naming::quote_ident(g))).collect(); format!( "{seq_cte}, @@ -519,7 +517,7 @@ fn compute_density( } else { group_by .iter() - .map(|g| format!("data.\"{}\" IS NOT DISTINCT FROM bandwidth.\"{}\"", g, g)) + .map(|g| { let q = naming::quote_ident(g); format!("data.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) .collect::>() .join(" AND ") }; @@ -530,7 +528,7 @@ fn compute_density( } else { let grid_data_conds: Vec = group_by .iter() - .map(|g| format!("grid.\"{}\" IS NOT DISTINCT FROM data.\"{}\"", g, g)) + .map(|g| { let q = naming::quote_ident(g); format!("grid.{q} IS NOT DISTINCT FROM data.{q}") }) .collect(); format!("WHERE {}", grid_data_conds.join(" AND ")) }; @@ -544,7 +542,7 @@ fn compute_density( ); // Build group-related SQL fragments - let grid_groups: Vec = group_by.iter().map(|g| format!("grid.\"{}\"", g)).collect(); + let grid_groups: Vec = group_by.iter().map(|g| format!("grid.{}", naming::quote_ident(g))).collect(); let aggregation = format!( "GROUP BY grid.x{grid_group_by} ORDER BY grid.x{grid_group_by}", @@ -554,7 +552,7 @@ fn compute_density( let groups = if group_by.is_empty() { String::new() } else { - let quoted: Vec = 
group_by.iter().map(|g| format!("\"{}\"", g)).collect(); + let quoted: Vec = group_by.iter().map(|g| naming::quote_ident(g)).collect(); format!("{},", quoted.join(", ")) }; diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index 58a660a8..9176956e 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; -use super::types::{get_column_name, CLOSED_VALUES, POSITION_VALUES}; +use super::types::{get_quoted_column_name, CLOSED_VALUES, POSITION_VALUES}; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, StatResult, @@ -125,7 +125,7 @@ fn stat_histogram( dialect: &dyn SqlDialect, ) -> Result { // Get x column name from aesthetics - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Histogram requires 'x' aesthetic mapping".to_string()) })?; @@ -213,7 +213,7 @@ fn stat_histogram( )); } if let Some(weight_col) = weight_value.column_name() { - format!("SUM({})", weight_col) + format!("SUM({})", naming::quote_ident(weight_col)) } else { "COUNT(*)".to_string() } diff --git a/src/plot/layer/geom/rect.rs b/src/plot/layer/geom/rect.rs index 840276b2..6aeeb928 100644 --- a/src/plot/layer/geom/rect.rs +++ b/src/plot/layer/geom/rect.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; -use super::types::get_column_name; use super::types::POSITION_VALUES; +use super::types::{get_column_name, get_quoted_column_name}; use super::{DefaultAesthetics, GeomTrait, GeomType, ParamConstraint, StatResult}; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; @@ -130,15 +130,17 @@ fn process_direction( _ => unreachable!("axis must be 'x' or 'y'"), }; - // Get column names from MAPPING, with SETTING fallback for size - let center = get_column_name(aesthetics, center_aes); - let min = 
get_column_name(aesthetics, min_aes); - let max = get_column_name(aesthetics, max_aes); - let size = get_column_name(aesthetics, size_aes) + // Get unquoted center name for schema lookup + let center_unquoted = get_column_name(aesthetics, center_aes); + let center = center_unquoted.as_deref().map(naming::quote_ident); + let min = get_quoted_column_name(aesthetics, min_aes); + let max = get_quoted_column_name(aesthetics, max_aes); + // SETTING fallback for size is a literal value, no quoting needed. + let size = get_quoted_column_name(aesthetics, size_aes) .or_else(|| parameters.get(size_aes).map(|v| v.to_string())); // Detect if discrete by checking schema - let is_discrete = center + let is_discrete = center_unquoted .as_ref() .and_then(|col| schema.iter().find(|c| &c.name == col)) .map(|c| c.is_discrete) @@ -208,7 +210,7 @@ fn stat_rect( let mut select_parts: Vec = schema .iter() .filter(|col| !consumed_columns.contains(&col.name)) - .map(|col| col.name.clone()) + .map(|col| naming::quote_ident(&col.name)) .collect(); // Add X direction SELECT parts and collect stat columns @@ -223,7 +225,7 @@ fn stat_rect( // Build transformed query let transformed_query = format!( - "SELECT {} FROM ({}) AS __ggsql_rect_stat__", + "SELECT {} FROM ({}) AS \"__ggsql_rect_stat__\"", select_list, query ); @@ -446,44 +448,44 @@ mod tests { ( "xmin + xmax", vec!["pos1min", "pos1max"], - "__ggsql_aes_pos1min__", - "__ggsql_aes_pos1max__", + "\"__ggsql_aes_pos1min__\"", + "\"__ggsql_aes_pos1max__\"", ), ( "x + width", vec!["pos1", "width"], - "(__ggsql_aes_pos1__ - __ggsql_aes_width__ / 2.0)", - "(__ggsql_aes_pos1__ + __ggsql_aes_width__ / 2.0)", + "(\"__ggsql_aes_pos1__\" - \"__ggsql_aes_width__\" / 2.0)", + "(\"__ggsql_aes_pos1__\" + \"__ggsql_aes_width__\" / 2.0)", ), ( "x only (default width 1.0)", vec!["pos1"], - "(__ggsql_aes_pos1__ - 0.5)", - "(__ggsql_aes_pos1__ + 0.5)", + "(\"__ggsql_aes_pos1__\" - 0.5)", + "(\"__ggsql_aes_pos1__\" + 0.5)", ), ( "x + xmin", vec!["pos1", 
"pos1min"], - "__ggsql_aes_pos1min__", - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1min__)", + "\"__ggsql_aes_pos1min__\"", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1min__\")", ), ( "x + xmax", vec!["pos1", "pos1max"], - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1max__)", - "__ggsql_aes_pos1max__", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1max__\")", + "\"__ggsql_aes_pos1max__\"", ), ( "xmin + width", vec!["pos1min", "width"], - "__ggsql_aes_pos1min__", - "(__ggsql_aes_pos1min__ + __ggsql_aes_width__)", + "\"__ggsql_aes_pos1min__\"", + "(\"__ggsql_aes_pos1min__\" + \"__ggsql_aes_width__\")", ), ( "xmax + width", vec!["pos1max", "width"], - "(__ggsql_aes_pos1max__ - __ggsql_aes_width__)", - "__ggsql_aes_pos1max__", + "(\"__ggsql_aes_pos1max__\" - \"__ggsql_aes_width__\")", + "\"__ggsql_aes_pos1max__\"", ), ]; @@ -562,38 +564,38 @@ mod tests { ( "ymin + ymax", vec!["pos2min", "pos2max"], - "__ggsql_aes_pos2min__", - "__ggsql_aes_pos2max__", + "\"__ggsql_aes_pos2min__\"", + "\"__ggsql_aes_pos2max__\"", ), ( "y + height", vec!["pos2", "height"], - "(__ggsql_aes_pos2__ - __ggsql_aes_height__ / 2.0)", - "(__ggsql_aes_pos2__ + __ggsql_aes_height__ / 2.0)", + "(\"__ggsql_aes_pos2__\" - \"__ggsql_aes_height__\" / 2.0)", + "(\"__ggsql_aes_pos2__\" + \"__ggsql_aes_height__\" / 2.0)", ), ( "y + ymin", vec!["pos2", "pos2min"], - "__ggsql_aes_pos2min__", - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2min__)", + "\"__ggsql_aes_pos2min__\"", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2min__\")", ), ( "y + ymax", vec!["pos2", "pos2max"], - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2max__)", - "__ggsql_aes_pos2max__", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2max__\")", + "\"__ggsql_aes_pos2max__\"", ), ( "ymin + height", vec!["pos2min", "height"], - "__ggsql_aes_pos2min__", - "(__ggsql_aes_pos2min__ + __ggsql_aes_height__)", + "\"__ggsql_aes_pos2min__\"", + "(\"__ggsql_aes_pos2min__\" + \"__ggsql_aes_height__\")", ), ( "ymax + height", 
vec!["pos2max", "height"], - "(__ggsql_aes_pos2max__ - __ggsql_aes_height__)", - "__ggsql_aes_pos2max__", + "(\"__ggsql_aes_pos2max__\" - \"__ggsql_aes_height__\")", + "\"__ggsql_aes_pos2max__\"", ), ]; @@ -687,8 +689,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS \"__ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"pos1".to_string())); assert!(stat_columns.contains(&"width".to_string())); assert!(stat_columns.contains(&"pos2min".to_string())); @@ -718,8 +720,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos2__ AS \"__ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); assert!(stat_columns.contains(&"pos2".to_string())); @@ -749,10 +751,10 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS \"__ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_pos2__ AS \"__ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert_eq!(stat_columns.len(), 4); } } @@ -782,8 +784,8 @@ mod tests { stat_columns, .. 
} => { - assert!(query.contains("(__ggsql_aes_pos1__ - 0.5)")); - assert!(query.contains("(__ggsql_aes_pos1__ + 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" - 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" + 0.5)")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); } @@ -879,12 +881,12 @@ mod tests { assert!(result.is_ok()); if let Ok(StatResult::Transformed { query, .. }) = result { - // Should include fill column (non-consumed aesthetic from schema) - assert!(query.contains("__ggsql_aes_fill__")); + // Should include fill column (non-consumed aesthetic from schema, quoted) + assert!(query.contains("\"__ggsql_aes_fill__\"")); // Should NOT include width/height as pass-through (they're consumed) // They should only appear as stat columns - assert!(query.contains("__ggsql_aes_width__ AS \"__ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_height__ AS \"__ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); } } diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index 806e2696..d81ef8ff 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -4,7 +4,7 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; -use crate::plot::geom::types::get_column_name; +use crate::plot::geom::types::get_quoted_column_name; use crate::plot::types::DefaultAestheticValue; use crate::plot::{ParameterValue, StatResult}; use crate::reader::SqlDialect; @@ -136,10 +136,10 @@ impl std::fmt::Display for Smooth { } fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { 
GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; @@ -198,10 +198,10 @@ fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; diff --git a/src/plot/layer/geom/types.rs b/src/plot/layer/geom/types.rs index ad7cf70a..b856737f 100644 --- a/src/plot/layer/geom/types.rs +++ b/src/plot/layer/geom/types.rs @@ -150,6 +150,11 @@ pub fn get_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option }) } +/// Helper to extract a double-quoted column name for use in SQL expressions. 
+pub fn get_quoted_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option { + get_column_name(aesthetics, aesthetic).map(|n| crate::naming::quote_ident(&n)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index e2f2b81d..8bdeca83 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -8,6 +8,7 @@ use super::{ expand_numeric_range, resolve_common_steps, ScaleDataContext, ScaleTypeKind, ScaleTypeTrait, TransformKind, CLOSED_VALUES, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_BINNED, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -727,7 +728,7 @@ fn build_bin_condition( (if is_first { ">=" } else { ">" }, "<=") }; - let quoted = format!("\"{}\"", column_name); + let quoted = naming::quote_ident(column_name); if oob_squish && is_first && is_last { // Single bin with squish: capture everything "TRUE".to_string() diff --git a/src/plot/scale/scale_type/continuous.rs b/src/plot/scale/scale_type/continuous.rs index 29cec212..8665b0ad 100644 --- a/src/plot/scale/scale_type/continuous.rs +++ b/src/plot/scale/scale_type/continuous.rs @@ -5,6 +5,7 @@ use polars::prelude::DataType; use super::{ ScaleTypeKind, ScaleTypeTrait, TransformKind, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_CONTINUOUS, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -215,7 +216,7 @@ impl ScaleTypeTrait for Continuous { match oob { OOB_CENSOR => { - let quoted = format!("\"{}\"", column_name); + let quoted = naming::quote_ident(column_name); Some(format!( "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", quoted, min, quoted, max, quoted @@ -224,7 +225,7 @@ impl ScaleTypeTrait for Continuous { OOB_SQUISH => { let min_s = min.to_string(); let max_s = max.to_string(); - let quoted = format!("\"{}\"", 
column_name); + let quoted = naming::quote_ident(column_name); let inner = dialect.sql_least(&[&max_s, &quoted]); Some(dialect.sql_greatest(&[&min_s, &inner])) } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 8a10aaff..7b60cf04 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -44,11 +44,11 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| format!("AND \"__ggsql_pct__\".\"{g}\" IS NOT DISTINCT FROM \"__ggsql_qt__\".\"{g}\"")) + .map(|g| { let q = crate::naming::quote_ident(g); format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") }) .collect::<Vec<String>>() .join(" "); - let quoted_column = format!("\"{}\"", column); + let quoted_column = crate::naming::quote_ident(column); format!( "(SELECT QUANTILE_CONT({column}, {fraction}) \ FROM ({from}) AS \"__ggsql_pct__\" \ diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 6c04e4d5..da505d8c 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -186,13 +186,13 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. 
let group_filter = groups .iter() - .map(|g| format!("AND \"__ggsql_pct__\".\"{g}\" IS NOT DISTINCT FROM \"__ggsql_qt__\".\"{g}\"")) + .map(|g| { let q = crate::naming::quote_ident(g); format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") }) .collect::>() .join(" "); let lo_tile = (fraction * 4.0).ceil() as usize; let hi_tile = lo_tile + 1; - let quoted_column = format!("\"{}\"", column); + let quoted_column = crate::naming::quote_ident(column); format!( "(SELECT (\ From a93f3d4e679e996542edca17bc7a86d85485e2a7 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 15:07:34 +0100 Subject: [PATCH 18/20] cargo fmt --- ggsql-jupyter/src/data_explorer.rs | 43 ++++++++++++++--------------- src/execute/casting.rs | 1 - src/plot/layer/geom/bar.rs | 6 +++- src/plot/layer/geom/density.rs | 21 ++++++++++---- src/plot/scale/scale_type/binned.rs | 22 +++++++++------ src/reader/duckdb.rs | 5 +++- src/reader/mod.rs | 15 ++++++---- src/reader/snowflake.rs | 4 +-- 8 files changed, 69 insertions(+), 48 deletions(-) diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs index a5794141..6ecfbfb4 100644 --- a/ggsql-jupyter/src/data_explorer.rs +++ b/ggsql-jupyter/src/data_explorer.rs @@ -71,7 +71,7 @@ impl DataExplorerState { let table = &path[2]; let table_path = format!( - "\"{}\".\"{}\".\"{}\"" , + "\"{}\".\"{}\".\"{}\"", catalog.replace('"', "\"\""), schema.replace('"', "\"\""), table.replace('"', "\"\""), @@ -252,10 +252,7 @@ impl DataExplorerState { .get("first_index") .and_then(|v| v.as_u64()) .unwrap_or(0) as usize; - let last = spec - .get("last_index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; + let last = spec.get("last_index").and_then(|v| v.as_u64()).unwrap_or(0) as usize; (first, last) }) .unwrap_or((0, 0)); @@ -265,16 +262,20 @@ impl DataExplorerState { // Collect requested column indices let col_indices: Vec = selections .iter() - .filter_map(|sel| sel.get("column_index").and_then(|v| 
v.as_u64()).map(|n| n as usize)) + .filter_map(|sel| { + sel.get("column_index") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + }) .collect(); // Build column list for SELECT let col_names: Vec<String> = col_indices .iter() .filter_map(|&idx| { - self.columns.get(idx).map(|col| { - format!("\"{}\"", col.name.replace('"', "\"\"")) - }) + self.columns + .get(idx) + .map(|col| format!("\"{}\"", col.name.replace('"', "\"\""))) }) .collect(); @@ -508,13 +509,11 @@ impl DataExplorerState { }) }; - let get_i64 = |name: &str| -> Option<i64> { - get_str(name).and_then(|s| s.parse::<i64>().ok()) - }; + let get_i64 = + |name: &str| -> Option<i64> { get_str(name).and_then(|s| s.parse::<i64>().ok()) }; - let get_f64 = |name: &str| -> Option<f64> { - get_str(name).and_then(|s| s.parse::<f64>().ok()) - }; + let get_f64 = + |name: &str| -> Option<f64> { get_str(name).and_then(|s| s.parse::<f64>().ok()) }; let mut result = json!({}); @@ -552,8 +551,7 @@ impl DataExplorerState { // DuckDB, NTILE fallback on other backends) let col_name = col.name.replace('"', "\"\""); let from_query = format!("SELECT * FROM {}", self.table_path); - let median_expr = - dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); + let median_expr = dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); let median_sql = format!("SELECT {} AS \"median_val\"", median_expr); if let Ok(median_df) = reader.execute_sql(&median_sql) { if let Some(v) = median_df @@ -825,10 +823,7 @@ impl DataExplorerState { params: &Value, reader: &dyn Reader, ) -> Option { - let limit = params - .get("limit") - .and_then(|v| v.as_u64()) - .unwrap_or(8) as usize; + let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(8) as usize; let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); @@ -975,9 +970,11 @@ fn clean_type_name(type_name: &str) -> String { "FIXED" => { let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); if scale > 0 { - format!("NUMBER({},{})", + format!( + "NUMBER({},{})", obj.get("precision").and_then(|v| 
v.as_i64()).unwrap_or(38), - scale) + scale + ) } else { "NUMBER".to_string() } diff --git a/src/execute/casting.rs b/src/execute/casting.rs index c87ea5fe..0f731812 100644 --- a/src/execute/casting.rs +++ b/src/execute/casting.rs @@ -184,4 +184,3 @@ pub fn update_type_info_for_casting(type_info: &mut [TypeInfo], requirements: &[ } } } - diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index f62aa5bb..4915844b 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -178,7 +178,11 @@ fn stat_bar_count( if let Some(weight_col) = weight_value.column_name() { if schema_columns.contains(weight_col) { // weight column exists - use SUM (but still call it "count") - format!("SUM({}) AS \"{}\"", naming::quote_ident(weight_col), stat_count) + format!( + "SUM({}) AS \"{}\"", + naming::quote_ident(weight_col), + stat_count + ) } else { // weight mapped but column doesn't exist - fall back to COUNT // (this shouldn't happen with upfront validation, but handle gracefully) diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index cfeaf556..2653e448 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -459,8 +459,10 @@ fn build_grid_cte( format!("full_grid.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) .collect(); - let grid_groups_select: Vec = - groups.iter().map(|g| format!("full_grid.{}", naming::quote_ident(g))).collect(); + let grid_groups_select: Vec = groups + .iter() + .map(|g| format!("full_grid.{}", naming::quote_ident(g))) + .collect(); format!( "{seq_cte}, @@ -517,7 +519,10 @@ fn compute_density( } else { group_by .iter() - .map(|g| { let q = naming::quote_ident(g); format!("data.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) + .map(|g| { + let q = naming::quote_ident(g); + format!("data.{q} IS NOT DISTINCT FROM bandwidth.{q}") + }) .collect::>() .join(" AND ") }; @@ -528,7 +533,10 @@ fn compute_density( } else { let grid_data_conds: Vec = group_by .iter() - .map(|g| { let q = 
naming::quote_ident(g); format!("grid.{q} IS NOT DISTINCT FROM data.{q}") }) + .map(|g| { + let q = naming::quote_ident(g); + format!("grid.{q} IS NOT DISTINCT FROM data.{q}") + }) .collect(); format!("WHERE {}", grid_data_conds.join(" AND ")) }; @@ -542,7 +550,10 @@ fn compute_density( ); // Build group-related SQL fragments - let grid_groups: Vec = group_by.iter().map(|g| format!("grid.{}", naming::quote_ident(g))).collect(); + let grid_groups: Vec = group_by + .iter() + .map(|g| format!("grid.{}", naming::quote_ident(g))) + .collect(); let aggregation = format!( "GROUP BY grid.x{grid_group_by} ORDER BY grid.x{grid_group_by}", diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index 8bdeca83..b491ca1f 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -1229,7 +1229,10 @@ mod tests { !sql.contains("CAST("), "SQL should not contain CAST when column is numeric" ); - assert!(sql.contains("\"value\" >= 0"), "SQL should use quoted column name"); + assert!( + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name" + ); } #[test] @@ -1505,9 +1508,9 @@ mod tests { "left", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN \"value\" < 10 THEN 5", // First bin extends to -∞ - "WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15", // Middle bin - "WHEN \"value\" >= 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" < 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15", // Middle bin + "WHEN \"value\" >= 20 THEN 25", // Last bin extends to +∞ ], ), // closed="right" with 3 bins (4 breaks) @@ -1515,9 +1518,9 @@ mod tests { "right", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN \"value\" <= 10 THEN 5", // First bin extends to -∞ - "WHEN \"value\" > 10 AND \"value\" <= 20 THEN 15", // Middle bin - "WHEN \"value\" > 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" <= 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" > 10 AND \"value\" <= 20 
THEN 15", // Middle bin + "WHEN \"value\" > 20 THEN 25", // Last bin extends to +∞ ], ), ]; @@ -1651,7 +1654,10 @@ mod tests { ), ( false, - vec!["\"col\" >= 0 AND \"col\" < 10", "\"col\" >= 10 AND \"col\" <= 20"], + vec![ + "\"col\" >= 0 AND \"col\" < 10", + "\"col\" >= 10 AND \"col\" <= 20", + ], ), ]; diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 7b60cf04..beb8e7f1 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -44,7 +44,10 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| { let q = crate::naming::quote_ident(g); format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") }) + .map(|g| { + let q = crate::naming::quote_ident(g); + format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") + }) .collect::>() .join(" "); diff --git a/src/reader/mod.rs b/src/reader/mod.rs index da505d8c..c9b03464 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -100,8 +100,7 @@ pub trait SqlDialect { /// SQL to list catalog names. Returns rows with column `catalog_name`. fn sql_list_catalogs(&self) -> String { - "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name" - .into() + "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name".into() } /// SQL to list schema names within a catalog. Returns rows with column `schema_name`. @@ -186,7 +185,10 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. 
let group_filter = groups .iter() - .map(|g| { let q = crate::naming::quote_ident(g); format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") }) + .map(|g| { + let q = crate::naming::quote_ident(g); + format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") + }) .collect::>() .join(" "); @@ -474,9 +476,10 @@ pub fn execute_with_reader(reader: &dyn Reader, query: &str) -> Result { let prepared_data = prepare_data_with_reader(query, reader)?; - let plot = prepared_data.specs.into_iter().next().ok_or_else(|| { - GgsqlError::ValidationError("No visualization spec found".to_string()) - })?; + let plot = + prepared_data.specs.into_iter().next().ok_or_else(|| { + GgsqlError::ValidationError("No visualization spec found".to_string()) + })?; let layer_sql = vec![None; plot.layers.len()]; let stat_sql = vec![None; plot.layers.len()]; diff --git a/src/reader/snowflake.rs b/src/reader/snowflake.rs index 8752b8ad..33d5dd9e 100644 --- a/src/reader/snowflake.rs +++ b/src/reader/snowflake.rs @@ -25,8 +25,6 @@ impl super::SqlDialect for SnowflakeDialect { let catalog_ident = catalog.replace('"', "\"\""); let schema_ident = schema.replace('"', "\"\""); let table_ident = table.replace('"', "\"\""); - format!( - "SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"" - ) + format!("SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"") } } From 27a65204709764c6872acc93286c2a6a4a12d5c6 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 16:06:47 +0100 Subject: [PATCH 19/20] Keep clippy happy --- src/reader/odbc.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs index d6f47742..d5b6339d 100644 --- a/src/reader/odbc.rs +++ b/src/reader/odbc.rs @@ -322,14 +322,14 @@ fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result { .map_err(|e| GgsqlError::ReaderError(format!("Failed to fetch batch: {}", e)))? 
{ let num_rows = batch.num_rows(); - for col_idx in 0..col_count { + for (col_idx, column) in columns.iter_mut().enumerate() { for row_idx in 0..num_rows { let value = batch .at_as_str(col_idx, row_idx) .ok() .flatten() .map(|s| s.to_string()); - columns[col_idx].push(value); + column.push(value); } } } From 2317ca82472da8abdad6664059481f2789bf3b32 Mon Sep 17 00:00:00 2001 From: George Stagg Date: Tue, 31 Mar 2026 16:21:39 +0100 Subject: [PATCH 20/20] Install ODBC in GHA workflows --- .github/workflows/R-CMD-check.yaml | 3 +++ .github/workflows/build.yaml | 3 +++ .github/workflows/publish.yaml | 3 +++ 3 files changed, 9 insertions(+) diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 184adfb8..08874b95 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -28,6 +28,9 @@ jobs: - name: Install tree-sitter-cli run: npm install -g tree-sitter-cli + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0937dc62..7d1aef25 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,6 +32,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 5d85746a..a64c898b 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -33,6 +33,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable