diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 184adfb8..08874b95 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -28,6 +28,9 @@ jobs: - name: Install tree-sitter-cli run: npm install -g tree-sitter-cli + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0937dc62..7d1aef25 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,6 +32,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 5d85746a..a64c898b 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -33,6 +33,9 @@ jobs: - name: Install LLVM run: sudo apt-get install -y llvm + - name: Install ODBC + run: sudo apt-get install -y unixodbc-dev + - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/Cargo.lock b/Cargo.lock index b4828bb5..5f819190 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,6 +64,31 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "android-activity" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f2a1bb052857d5dd49572219344a7332b31b76405648eabac5bc68978251bcd" +dependencies = [ + "android-properties", + "bitflags 2.11.0", + "cc", + "jni 0.22.4", + "libc", + "log", + "ndk", + "ndk-context", + "ndk-sys", + "num_enum", + "thiserror 2.0.18", +] + +[[package]] +name = "android-properties" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"fc7eb209b1518d6bb87b283c20095f5228ecda460da70b44f0802523dea6da04" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -612,6 +637,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" +dependencies = [ + "objc2", +] + [[package]] name = "borrow-or-share" version = "0.2.4" @@ -743,6 +777,20 @@ dependencies = [ "serde", ] +[[package]] +name = "calloop" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec" +dependencies = [ + "bitflags 2.11.0", + "log", + "polling", + "rustix 0.38.44", + "slab", + "thiserror 1.0.69", +] + [[package]] name = "cast" version = "0.3.0" @@ -1136,6 +1184,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "cursor-icon" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" + [[package]] name = "dashmap" version = "5.5.3" @@ -1204,6 +1258,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "dispatch" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" + [[package]] name = "displaydoc" version = "0.2.5" @@ -1224,6 +1284,12 @@ dependencies = [ "libloading", ] +[[package]] +name = "dpi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b14ccef22fc6f5a8f4d7d768562a182c04ce9a3b3157b91390b52ddfdf1a76" + [[package]] name = "duckdb" version = "1.4.4" @@ -1550,7 +1616,7 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" dependencies = [ - 
"rustix", + "rustix 1.1.4", "windows-sys 0.59.0", ] @@ -1717,6 +1783,7 @@ dependencies = [ "csscolorparser", "duckdb", "jsonschema", + "odbc-api", "palette", "plotters", "polars", @@ -1730,8 +1797,10 @@ dependencies = [ "serde", "serde_json", "sprintf", + "tempfile", "thiserror 1.0.69", "tokio", + "toml_edit 0.22.27", "tower-http 0.5.2", "tracing", "tracing-subscriber", @@ -1911,6 +1980,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -2263,19 +2338,68 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.0", "log", "thiserror 1.0.69", "walkdir", "windows-sys 0.45.0", ] +[[package]] +name = "jni" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" +dependencies = [ + "cfg-if", + "combine", + "jni-macros", + "jni-sys 0.4.1", + "log", + "simd_cesu8", + "thiserror 2.0.18", + "walkdir", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn 2.0.117", +] + [[package]] name = "jni-sys" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.117", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -2462,6 +2586,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -2590,6 +2720,36 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "ndk" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3f42e7bbe13d351b6bead8286a43aac9534b82bd3cc43e47037f012ebfd62d4" +dependencies = [ + "bitflags 2.11.0", + "jni-sys 0.3.0", + "log", + "ndk-sys", + "num_enum", + "raw-window-handle", + "thiserror 1.0.69", +] + +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + +[[package]] +name = "ndk-sys" +version = "0.6.0+11769913" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873" +dependencies = [ + "jni-sys 0.3.0", +] + [[package]] name = "now" version = "0.1.3" @@ -2688,6 +2848,96 @@ dependencies = [ "libm", ] +[[package]] +name = "num_enum" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "objc-sys" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb91bdd390c7ce1a8607f35f3ca7151b65afc0ff5ff3b34fa350f7d7c7e4310" + +[[package]] +name = "objc2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46a785d4eeff09c14c487497c162e92766fbb3e4059a71840cecc03d9a50b804" +dependencies = [ + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-app-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" +dependencies = [ + "bitflags 2.11.0", + "block2", + "libc", + "objc2", + "objc2-core-data", + "objc2-core-image", + "objc2-foundation", + "objc2-quartz-core", +] + +[[package]] +name = "objc2-cloud-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + +[[package]] +name = "objc2-contacts" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-data" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", +] + [[package]] name = "objc2-core-foundation" version = "0.3.2" @@ -2697,6 +2947,96 @@ 
dependencies = [ "bitflags 2.11.0", ] +[[package]] +name = "objc2-core-image" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-core-location" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" +dependencies = [ + "block2", + "objc2", + "objc2-contacts", + "objc2-foundation", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" +dependencies = [ + "bitflags 2.11.0", + "block2", + "dispatch", + "libc", + "objc2", +] + +[[package]] +name = "objc2-link-presentation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" +dependencies = [ + "block2", + "objc2", + "objc2-app-kit", + "objc2-foundation", +] + +[[package]] +name = "objc2-metal" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-foundation", + 
"objc2-metal", +] + +[[package]] +name = "objc2-symbols" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc" +dependencies = [ + "objc2", + "objc2-foundation", +] + [[package]] name = "objc2-system-configuration" version = "0.3.2" @@ -2706,6 +3046,51 @@ dependencies = [ "objc2-core-foundation", ] +[[package]] +name = "objc2-ui-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-cloud-kit", + "objc2-core-data", + "objc2-core-image", + "objc2-core-location", + "objc2-foundation", + "objc2-link-presentation", + "objc2-quartz-core", + "objc2-symbols", + "objc2-uniform-type-identifiers", + "objc2-user-notifications", +] + +[[package]] +name = "objc2-uniform-type-identifiers" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-user-notifications" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" +dependencies = [ + "bitflags 2.11.0", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + [[package]] name = "object" version = "0.37.3" @@ -2750,6 +3135,26 @@ dependencies = [ "web-time", ] +[[package]] +name = "odbc-api" +version = "13.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44e14665455e2817ac5b0dd9f65a3dc97e76e8f85eeac6e4301b7cf9da451884" +dependencies = [ + "atoi", + "log", + "odbc-sys", + "thiserror 2.0.18", + "widestring", + "winit", +] + +[[package]] +name = "odbc-sys" +version = "0.25.1" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ecdb20f7c165083ad1bc9f55122f677725e257716a5bc83e5413d5654b7d6f1" + [[package]] name = "once_cell" version = "1.21.4" @@ -2774,6 +3179,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "orbclient" +version = "0.3.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59aed3b33578edcfa1bc96a321d590d31832b6ad55a26f0313362ce687e9abd6" +dependencies = [ + "libc", + "libredox", +] + [[package]] name = "outref" version = "0.5.2" @@ -2963,6 +3378,26 @@ dependencies = [ "uncased", ] +[[package]] +name = "pin-project" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "pin-project-lite" version = "0.2.17" @@ -3596,6 +4031,20 @@ dependencies = [ "version_check", ] +[[package]] +name = "polling" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix 1.1.4", + "windows-sys 0.61.2", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -3679,7 +4128,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit", + "toml_edit 0.25.4+spec-1.1.0", ] 
[[package]] @@ -3986,6 +4435,12 @@ dependencies = [ "bitflags 2.11.0", ] +[[package]] +name = "raw-window-handle" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" + [[package]] name = "rayon" version = "1.11.0" @@ -4026,6 +4481,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -4330,6 +4794,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.1.4" @@ -4339,7 +4816,7 @@ dependencies = [ "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.12.1", "windows-sys 0.61.2", ] @@ -4389,7 +4866,7 @@ checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ "core-foundation 0.10.1", "core-foundation-sys", - "jni", + "jni 0.21.1", "log", "once_cell", "rustls", @@ -4649,6 +5126,16 @@ dependencies = [ "value-trait", ] +[[package]] +name = "simd_cesu8" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + [[package]] name = "simdutf8" version = "0.1.5" @@ -4682,6 +5169,15 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" 
+[[package]] +name = "smol_str" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" +dependencies = [ + "serde", +] + [[package]] name = "snap" version = "1.1.1" @@ -4911,7 +5407,7 @@ dependencies = [ "fastrand", "getrandom 0.4.2", "once_cell", - "rustix", + "rustix 1.1.4", "windows-sys 0.61.2", ] @@ -5098,6 +5594,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" + [[package]] name = "toml_datetime" version = "1.0.0+spec-1.1.0" @@ -5107,6 +5609,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "toml_datetime 0.6.11", + "toml_write", + "winnow", +] + [[package]] name = "toml_edit" version = "0.25.4+spec-1.1.0" @@ -5114,7 +5628,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" dependencies = [ "indexmap", - "toml_datetime", + "toml_datetime 1.0.0+spec-1.1.0", "toml_parser", "winnow", ] @@ -5128,6 +5642,12 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "tower" version = "0.5.3" @@ -5748,6 +6268,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version 
= "0.3.9" @@ -6069,6 +6595,46 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winit" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6755fa58a9f8350bd1e472d4c3fcc25f824ec358933bba33306d0b63df5978d" +dependencies = [ + "android-activity", + "atomic-waker", + "bitflags 2.11.0", + "block2", + "calloop", + "cfg_aliases", + "concurrent-queue", + "core-foundation 0.9.4", + "core-graphics", + "cursor-icon", + "dpi", + "js-sys", + "libc", + "ndk", + "objc2", + "objc2-app-kit", + "objc2-foundation", + "objc2-ui-kit", + "orbclient", + "pin-project", + "raw-window-handle", + "redox_syscall 0.4.1", + "rustix 0.38.44", + "smol_str", + "tracing", + "unicode-segmentation", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "web-time", + "windows-sys 0.52.0", + "xkbcommon-dl", +] + [[package]] name = "winnow" version = "0.7.15" @@ -6197,9 +6763,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "rustix", + "rustix 1.1.4", +] + +[[package]] +name = "xkbcommon-dl" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039de8032a9a8856a6be89cea3e5d12fdd82306ab7c94d74e6deab2460651c5" +dependencies = [ + "bitflags 2.11.0", + "dlib", + "log", + "once_cell", + "xkeysym", ] +[[package]] +name = "xkeysym" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56" + [[package]] name = "xxhash-rust" version = "0.8.15" diff --git a/Cargo.toml b/Cargo.toml index 9b4f3eea..751c07a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,10 @@ arrow = { version = "56", default-features = false, features = ["ipc"] } postgres = "0.19" 
rusqlite = { version = "0.38", features = ["bundled", "chrono", "functions", "window"] } +# ODBC +odbc-api = "13" +toml_edit = "0.22" + # Writers plotters = "0.3" diff --git a/ODBC.md b/ODBC.md new file mode 100644 index 00000000..51cdcd50 --- /dev/null +++ b/ODBC.md @@ -0,0 +1,312 @@ +# Positron Connections Pane Integration for ggsql + +## Context + +ggsql's Jupyter kernel (`ggsql-jupyter`) and VS Code extension (`ggsql-vscode`) currently have no integration with Positron's Connections pane. The kernel is hardcoded to `duckdb://memory` with no way to configure the database connection. This plan adds: + +1. **Connection comm protocol** in the kernel — so database schemas appear in the Connections pane +2. **Connection drivers** in the extension — so users can create connections via the "New Connection" dialog +3. **Generic ODBC reader** in the library — supporting Snowflake (with Workbench credentials), PostgreSQL, SQL Server, etc. +4. **Dynamic connection switching** in the kernel — via meta-commands executed from the connection dialog + +## Architecture Overview + +``` +New Connection Dialog ggsql-jupyter Kernel + (ggsql-vscode) (Rust) + │ │ + │ generateCode() → │ + │ "-- @connect: odbc://snowflake?…" │ + │ │ + │ connect() → │ + │ positron.runtime.executeCode()───┤ + │ │ detect meta-command + │ │ create OdbcReader + │ │ open positron.connection comm + │ │ + │ ◄── comm_open ──────┤ (kernel initiates) + │ │ + Connections Pane │ + │── list_objects([]) ─────────────►│ SELECT … FROM information_schema + │◄─ [{name:"public",kind:"schema"}]│ + │── list_fields([schema,table]) ──►│ SELECT … FROM information_schema.columns + │◄─ [{name:"id",dtype:"integer"}] │ +``` + +Key insight: The **kernel** opens the `positron.connection` comm (backend-initiated), unlike variables/ui comms which are frontend-initiated. + +--- + +## Part 1: ODBC Reader (`src/reader/odbc.rs`) + +### New file: `src/reader/odbc.rs` + +Generic ODBC reader using `odbc-api` crate. Implements `Reader` trait. 
+ +**Connection string format**: `odbc://` prefix + raw ODBC connection string (no URI parsing of the payload) +- `odbc://Driver=Snowflake;Server=myaccount.snowflakecomputing.com;Warehouse=WH` — Snowflake +- `odbc://Driver={PostgreSQL};Server=localhost;Database=mydb` — PostgreSQL +- The extension's driver dialogs build the ODBC string in `generateCode()` and prefix it with `odbc://` +- Parsing: strip `odbc://` prefix, pass remainder directly to `SQLDriverConnect` + +**Core implementation**: +- `OdbcReader::from_connection_string(uri)` — parse URI, detect credentials, connect +- `execute_sql(&self, sql)` — execute via `connection.execute()`, convert cursor → DataFrame +- Cursor → DataFrame conversion: iterate ODBC columnar buffers, map ODBC types to Polars types +- `register()` returns error (ODBC doesn't support temp table registration easily) +- `dialect()` returns dialect variant detected from DBMS info + +**Snowflake Workbench credential detection** (per `~/work/positron/CONNECTIONS.md`): +When `OdbcReader` sees `Driver=Snowflake` in the connection string and no `Token=` is present: +1. Read `SNOWFLAKE_HOME` env var +2. If path contains `"posit-workbench"`, parse `$SNOWFLAKE_HOME/connections.toml` +3. Extract `account` + `token` from `[workbench]` section +4. Inject `Authenticator=oauth;Token=` into the connection string before connecting +5. If no Workbench credentials found, connect as-is (user may have specified auth in the string) + +**Credential storage**: Trust Positron's secret storage — the full `-- @connect:` meta-command (including any credentials in the ODBC string) is stored in the `code` field of the connection comm metadata. Positron persists this in encrypted workspace secret storage for reconnection. + +**OdbcDialect**: Implements `SqlDialect` with a variant enum (Generic, Snowflake, PostgreSQL) detected from DBMS metadata at connection time. 
+ +### Modify: `src/reader/mod.rs` +- Add `#[cfg(feature = "odbc")] pub mod odbc;` and re-export `OdbcReader` +- Remove `where Self: Sized` bound from `fn execute()` + +### Modify: `src/reader/connection.rs` +- Add `ODBC(String)` variant to `ConnectionInfo` enum +- Parse `odbc://` prefix in `parse_connection_string()` + +### Modify: `src/execute/mod.rs` +- Change `prepare_data_with_reader(query: &str, reader: &R)` → `prepare_data_with_reader(query: &str, reader: &dyn Reader)` +- This is safe: all methods called on reader (`execute_sql`, `dialect`, `register`, `unregister`) are object-safe. `materialize_ctes` already takes `&dyn Reader`. + +### Modify: `Cargo.toml` +- Add feature: `odbc = ["dep:odbc-api", "dep:toml_edit"]` +- Add dependencies: `odbc-api = { version = "13", optional = true }`, `toml_edit = { version = "0.22", optional = true }` +- Add `"odbc"` to `all-readers` feature list + +--- + +## Part 2: Kernel Connection Comm Protocol (`ggsql-jupyter/`) + +### New file: `ggsql-jupyter/src/connection.rs` + +Module for database schema introspection via the reader. All methods query `information_schema` using `reader.execute_sql()`. 
+ +**Methods**: +- `list_objects(reader, path) -> Vec<ObjectSchema>`: + - Depth depends on dialect — `SqlDialect::has_catalogs()` (true for Snowflake, false for DuckDB/Postgres) + - **Without catalogs** (DuckDB, PostgreSQL): + - `[]` → query `information_schema.schemata` → return schemas + - `[schema]` → query `information_schema.tables WHERE table_schema = '<schema>'` → return tables/views + - **With catalogs** (Snowflake): + - `[]` → query `SHOW DATABASES` or `information_schema.schemata` grouped by catalog → return catalogs with `kind = "catalog"` + - `[catalog]` → query `information_schema.schemata WHERE catalog_name = '<catalog>'` → return schemas + - `[catalog, schema]` → query `information_schema.tables WHERE table_catalog = '<catalog>' AND table_schema = '<schema>'` → return tables/views +- `list_fields(reader, path) -> Vec<FieldSchema>`: + - **Without catalogs**: `[schema, table]` → query `information_schema.columns` + - **With catalogs**: `[catalog, schema, table]` → query `information_schema.columns` with catalog filter +- `contains_data(path) -> bool`: true when last element has `kind` == "table" or "view" +- **SQL safety**: All interpolated identifiers use standard quote-escaping (`'` → `''`) via a shared `escape_sql_string()` helper +- `get_icon(path) -> String`: return empty string (let Positron use defaults) +- `preview_object(path)`: stub — return null (Data Explorer comm is a separate future feature) +- `get_metadata(reader_uri, name) -> MetadataSchema`: return connection metadata + +**Dialect differences**: DuckDB's default schema is `main` (not `public`). Snowflake has a catalog→schema→table hierarchy. The `SqlDialect` trait gets new optional methods: +- `has_catalogs() -> bool` — false by default, true for Snowflake +- `schema_list_query() -> &str` — override for backends that don't support `information_schema.schemata` +- `default_schema() -> &str` — `"main"` for DuckDB, `"public"` for PostgreSQL, etc. 
+ +### Modify: `ggsql-jupyter/src/kernel.rs` + +**Add connection comm tracking**: +```rust +connection_comm_id: Option, +``` + +**Opening the comm** (kernel-initiated, sent on iopub after a successful `-- @connect:`): +```rust +// Send comm_open on iopub with target_name = "positron.connection" +self.send_iopub("comm_open", json!({ + "comm_id": new_uuid, + "target_name": "positron.connection", + "data": { + "name": display_name, // e.g. "DuckDB (memory)" or "Snowflake (myaccount)" + "language_id": "ggsql", + "host": host, // e.g. "memory" or "myaccount.snowflakecomputing.com" + "type": type_name, // e.g. "DuckDB" or "Snowflake" + "code": meta_command // e.g. "-- @connect: duckdb://memory" + } +}), parent).await?; +``` + +**Handle incoming comm_msg** for connection comm (JSON-RPC methods from Positron): +- Route `list_objects`, `list_fields`, `contains_data`, `get_icon`, `get_metadata` to `connection.rs` functions +- `preview_object`: stub — return null (full Data Explorer comm is a separate future feature) +- Send JSON-RPC responses back on shell via `send_shell_reply("comm_msg", ...)` + +**Handle comm_close**: clear `connection_comm_id` + +**Update comm_info_request**: include connection comm in response + +**Open connection comm on startup**: After kernel initializes with default DuckDB reader, automatically open a `positron.connection` comm so the Connections pane shows the default database immediately. +- Use `create_message(..., None)` for the no-parent startup comm_open (same pattern as `send_status_initial` at kernel.rs:749) +- Add `send_iopub_no_parent()` helper (or generalize `send_iopub` to accept `Option<&JupyterMessage>`) + +**Replacing connection comms on reader switch**: When `-- @connect:` switches readers: +1. If `connection_comm_id` is `Some`, send `comm_close` on iopub for the old comm ID first +2. Clear `connection_comm_id` +3. Open a new comm with a fresh UUID +4. 
This ensures no stale comm IDs linger and the Connections pane sees the old connection as disconnected + the new one as active + +### Modify: `ggsql-jupyter/src/executor.rs` + +**Make reader swappable**: +- Change `reader: DuckDBReader` → `reader: Box` +- Add `pub fn swap_reader(&mut self, new_reader: Box)` +- Add `pub fn reader(&self) -> &dyn Reader` accessor for connection.rs queries +- For visualization execution: call `self.reader.execute(query)` directly on the `Box` + +**Add meta-command handling**: +- In `execute()`, check if code starts with `-- @connect: ` +- Parse the connection URI from the meta-command +- Call `create_reader(uri)` (shared function) to build the new reader +- Swap the reader via `swap_reader()` +- Return a new `ExecutionResult::ConnectionChanged { uri, display_name }` variant + +**New `create_reader(uri)` function** (in executor.rs or a new `reader_factory.rs`): +- Parse connection string using `ggsql::reader::connection::parse_connection_string()` +- Match on `ConnectionInfo` variant to construct appropriate reader +- Feature-gated: DuckDB (default), SQLite (optional), ODBC (optional) + +### Modify: `ggsql-jupyter/src/main.rs` + +- Add `--reader` CLI arg (default: `"duckdb://memory"`) +- Pass the reader URI to `KernelServer::new(connection, reader_uri)` +- Kernel creates initial reader from this URI + +### Modify: `ggsql-jupyter/src/lib.rs` + +- Add `mod connection;` + +### Modify: `ggsql-jupyter/Cargo.toml` + +Current: `ggsql = { workspace = true, features = ["duckdb", "vegalite"] }` — only DuckDB. + +Add feature flags: +```toml +[features] +default = [] +sqlite = ["ggsql/sqlite"] +odbc = ["ggsql/odbc"] +all-readers = ["sqlite", "odbc"] +``` + +Update ggsql dep: `ggsql = { workspace = true, features = ["duckdb", "vegalite"] }` stays as default (DuckDB always available). 
+ +**`create_reader()` runtime error handling**: When `-- @connect:` requests a reader that isn't compiled in, return a clear error message to the user via execute_reply: +``` +Error: SQLite support is not compiled into this ggsql-jupyter binary. +Rebuild with: cargo build --features sqlite +``` +This uses `#[cfg(feature = "...")]` branches with a fallback error arm per reader type. + +--- + +## Part 3: VS Code Extension Connection Drivers (`ggsql-vscode/`) + +### New file: `ggsql-vscode/src/connections.ts` + +**`createConnectionDrivers(): positron.ConnectionsDriver[]`** + +Returns array of drivers to register. Each driver: +- `generateCode(inputs)` → returns `-- @connect: <uri>` meta-command string +- `connect(code)` → calls `positron.runtime.executeCode('ggsql', code, false)` to send the meta-command to the running kernel + +**DuckDB driver** (`driverId: 'ggsql-duckdb'`): +- Inputs: `database` (string, optional — empty = in-memory) +- generateCode: `-- @connect: duckdb://memory` or `-- @connect: duckdb://<path>` + +**Snowflake driver** (`driverId: 'ggsql-snowflake'`): +- Inputs: `account` (string, required), `warehouse` (string, required), `database` (string, optional), `schema` (string, optional) +- generateCode: builds full ODBC string e.g. `-- @connect: odbc://Driver=Snowflake;Server=<account>.snowflakecomputing.com;Warehouse=<warehouse>` + +**Generic ODBC driver** (`driverId: 'ggsql-odbc'`): +- Inputs: `connection_string` (string, required — raw ODBC connection string) +- generateCode: `-- @connect: odbc://<connection_string>` + +### Modify: `ggsql-vscode/src/extension.ts` + +In `activate()`, after registering the runtime manager: +```typescript +import { createConnectionDrivers } from './connections'; +// ... 
+const drivers = createConnectionDrivers(); +for (const driver of drivers) { + context.subscriptions.push(positronApi.connections.registerConnectionDriver(driver)); +} +``` + +### Modify: `ggsql-vscode/src/manager.ts` + +- Update `createKernelSpec()` to accept optional `readerUri` parameter +- Pass `--reader <uri>` in spawn args when `readerUri` is provided +- Add `getActiveSession()` method so `connect()` can check if a kernel is running + +--- + +## Implementation Order + +### Phase 1: Kernel meta-commands and dynamic reader switching +1. Modify `executor.rs` — make reader swappable, add meta-command detection, add `create_reader()` +2. Modify `main.rs` — add `--reader` CLI arg, pass to executor +3. Modify `kernel.rs` — handle `ConnectionChanged` result from executor +4. Test: start kernel with `--reader duckdb://memory`, verify meta-command works + +### Phase 2: Connection comm protocol +5. Create `connection.rs` — schema introspection via information_schema +6. Modify `kernel.rs` — open `positron.connection` comm on startup and after `-- @connect:`, handle incoming JSON-RPC methods +7. Test: start kernel in Positron, verify Connections pane shows DuckDB schema + +### Phase 3: ODBC reader +8. Create `src/reader/odbc.rs` — generic ODBC reader with cursor→DataFrame conversion +9. Add Workbench Snowflake credential detection +10. Modify `connection.rs`, `mod.rs`, `Cargo.toml` for ODBC feature +11. Test: connect to local ODBC data source, verify queries work + +### Phase 4: Extension connection drivers +12. Create `ggsql-vscode/src/connections.ts` — DuckDB, Snowflake, generic ODBC drivers +13. Modify `extension.ts` — register drivers on activation +14. Test: open New Connection dialog, create DuckDB connection, verify Connections pane updates + +### Phase 5: Integration & polish +15. End-to-end test: New Connection dialog → kernel connection → Connections pane browsing +16. 
Handle edge cases: connection failures, reader not compiled in, comm lifecycle + +--- + +## Verification + +1. **Unit tests**: Meta-command parsing, ODBC URI parsing, Workbench credential detection, schema introspection queries +2. **Integration test**: Start kernel with `--reader duckdb://memory`, execute `-- @connect: duckdb://memory`, verify comm_open message on iopub +3. **Manual Positron test**: Open ggsql session → Connections pane shows DuckDB → expand to see schemas/tables/columns → New Connection dialog → create Snowflake connection → Connections pane updates +4. **Existing tests**: Run `cargo test` to ensure no regressions in parser/reader/writer + +## Key files to modify + +| File | Change | +|------|--------| +| `src/execute/mod.rs` | Change `prepare_data_with_reader` to `&dyn Reader` | +| `src/reader/mod.rs` | Remove `Self: Sized` from `execute()`, add odbc module | +| `src/reader/odbc.rs` | **NEW** — Generic ODBC reader | +| `src/reader/connection.rs` | Add ODBC variant | +| `src/Cargo.toml` | Add odbc feature + deps | +| `ggsql-jupyter/src/connection.rs` | **NEW** — Schema introspection | +| `ggsql-jupyter/src/kernel.rs` | Connection comm protocol | +| `ggsql-jupyter/src/executor.rs` | Dynamic reader switching, meta-commands | +| `ggsql-jupyter/src/main.rs` | `--reader` CLI arg | +| `ggsql-jupyter/src/lib.rs` | Add connection module | +| `ggsql-jupyter/Cargo.toml` | Add odbc feature | +| `ggsql-vscode/src/connections.ts` | **NEW** — Connection drivers | +| `ggsql-vscode/src/extension.ts` | Register connection drivers | +| `ggsql-vscode/src/manager.ts` | Pass reader URI to kernel | diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml index 5b31ef8f..eadf493a 100644 --- a/ggsql-jupyter/Cargo.toml +++ b/ggsql-jupyter/Cargo.toml @@ -56,6 +56,12 @@ hex = "0.4" # UUID for message IDs uuid = { version = "1.0", features = ["v4"] } +[features] +default = ["all-readers"] +all-readers = ["sqlite", "odbc"] +odbc = ["ggsql/odbc"] +sqlite = 
["ggsql/sqlite"] + [dev-dependencies] # Test utilities tokio-test = "0.4" diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs new file mode 100644 index 00000000..0b28c80c --- /dev/null +++ b/ggsql-jupyter/src/connection.rs @@ -0,0 +1,213 @@ +//! Database schema introspection for the Positron Connections pane. +//! +//! Delegates introspection SQL to the reader's `SqlDialect`, which provides +//! backend-specific queries (e.g. `information_schema` for DuckDB/PostgreSQL, +//! `sqlite_master` / `PRAGMA` for SQLite). + +use ggsql::reader::Reader; +use serde::Serialize; +use serde_json::Value; + +/// An object in the schema hierarchy (catalog, schema, table, or view). +#[derive(Debug, Serialize)] +pub struct ObjectSchema { + pub name: String, + pub kind: String, +} + +/// A field (column) in a table. +#[derive(Debug, Serialize)] +pub struct FieldSchema { + pub name: String, + pub dtype: String, +} + +/// List objects at the given path depth. +/// +/// Path semantics (catalog → schema → table): +/// - `[]` → list catalogs +/// - `[catalog]` → list schemas in that catalog +/// - `[catalog, schema]` → list tables and views +pub fn list_objects(reader: &dyn Reader, path: &[String]) -> Result, String> { + match path.len() { + 0 => list_catalogs(reader), + 1 => list_schemas(reader, &path[0]), + 2 => list_tables(reader, &path[0], &path[1]), + _ => Ok(vec![]), + } +} + +/// List fields (columns) for the object at the given path. +/// +/// - `[catalog, schema, table]` → list columns +pub fn list_fields(reader: &dyn Reader, path: &[String]) -> Result, String> { + if path.len() == 3 { + list_columns(reader, &path[0], &path[1], &path[2]) + } else { + Ok(vec![]) + } +} + +/// Whether the path points to an object that contains data (table or view). 
+pub fn contains_data(path: &[Value]) -> bool { + path.last() + .and_then(|v| v.get("kind")) + .and_then(|k| k.as_str()) + .map(|k| k == "table" || k == "view") + .unwrap_or(false) +} + +fn list_catalogs(reader: &dyn Reader) -> Result, String> { + let sql = reader.dialect().sql_list_catalogs(); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list catalogs: {}", e))?; + + let col = df + .column("catalog_name") + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing catalog_name/name column: {}", e))?; + + let mut catalogs = Vec::new(); + for i in 0..df.height() { + if let Ok(val) = col.get(i) { + let name = val.to_string().trim_matches('"').to_string(); + catalogs.push(ObjectSchema { + name, + kind: "catalog".to_string(), + }); + } + } + Ok(catalogs) +} + +fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result, String> { + let sql = reader.dialect().sql_list_schemas(catalog); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list schemas: {}", e))?; + + let col = df + .column("schema_name") + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing schema_name/name column: {}", e))?; + + let mut schemas = Vec::new(); + for i in 0..df.height() { + if let Ok(val) = col.get(i) { + let name = val.to_string().trim_matches('"').to_string(); + schemas.push(ObjectSchema { + name, + kind: "schema".to_string(), + }); + } + } + Ok(schemas) +} + +fn list_tables( + reader: &dyn Reader, + catalog: &str, + schema: &str, +) -> Result, String> { + let sql = reader.dialect().sql_list_tables(catalog, schema); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list tables: {}", e))?; + + let name_col = df + .column("table_name") + .or_else(|_| df.column("name")) + .map_err(|e| format!("Missing table_name/name column: {}", e))?; + let type_col = df + .column("table_type") + .or_else(|_| df.column("kind")) + .map_err(|e| format!("Missing table_type/kind column: {}", e))?; + + let mut objects = 
Vec::new(); + for i in 0..df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let table_type = type_val.to_string().trim_matches('"').to_uppercase(); + let kind = if table_type.contains("VIEW") { + "view" + } else { + "table" + }; + objects.push(ObjectSchema { + name, + kind: kind.to_string(), + }); + } + } + Ok(objects) +} + +fn list_columns( + reader: &dyn Reader, + catalog: &str, + schema: &str, + table: &str, +) -> Result, String> { + let sql = reader.dialect().sql_list_columns(catalog, schema, table); + let df = reader + .execute_sql(&sql) + .map_err(|e| format!("Failed to list columns: {}", e))?; + + let name_col = df + .column("column_name") + .map_err(|e| format!("Missing column_name column: {}", e))?; + let type_col = df + .column("data_type") + .map_err(|e| format!("Missing data_type column: {}", e))?; + + let mut fields = Vec::new(); + for i in 0..df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let dtype = type_val.to_string().trim_matches('"').to_string(); + fields.push(FieldSchema { name, dtype }); + } + } + Ok(fields) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contains_data_table() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + serde_json::json!({"name": "users", "kind": "table"}), + ]; + assert!(contains_data(&path)); + } + + #[test] + fn test_contains_data_schema() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + ]; + assert!(!contains_data(&path)); + } + + #[test] + fn test_contains_data_catalog() { + let path = vec![serde_json::json!({"name": "memory", "kind": "catalog"})]; + assert!(!contains_data(&path)); + } + + #[test] + fn 
test_contains_data_view() { + let path = vec![ + serde_json::json!({"name": "memory", "kind": "catalog"}), + serde_json::json!({"name": "main", "kind": "schema"}), + serde_json::json!({"name": "my_view", "kind": "view"}), + ]; + assert!(contains_data(&path)); + } +} diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs new file mode 100644 index 00000000..6ecfbfb4 --- /dev/null +++ b/ggsql-jupyter/src/data_explorer.rs @@ -0,0 +1,1017 @@ +//! Data explorer backend for the Positron data viewer. +//! +//! Implements the `positron.dataExplorer` comm protocol, providing SQL-backed +//! paginated data access. No full table load — each `get_data_values` request +//! issues a `SELECT ... LIMIT/OFFSET` query. + +use ggsql::reader::Reader; +use serde_json::{json, Value}; + +/// Result of handling an RPC call. +pub struct RpcResponse { + /// The JSON-RPC result to send as the reply. + pub result: Value, + /// An optional event to send on iopub (e.g. `return_column_profiles`). + pub event: Option, +} + +/// An asynchronous event to send back on the comm after the RPC reply. +pub struct RpcEvent { + pub method: String, + pub params: Value, +} + +impl RpcResponse { + /// Create a simple reply with no async event. + pub fn reply(result: Value) -> Self { + Self { + result, + event: None, + } + } +} + +/// Cached column metadata for a table. +#[derive(Debug, Clone)] +pub struct ColumnInfo { + pub name: String, + /// Backend-specific type name (e.g. "INTEGER", "VARCHAR"). + pub type_name: String, + /// Positron display type (e.g. "integer", "string"). + pub type_display: String, +} + +/// State for one open data explorer comm. +pub struct DataExplorerState { + /// Fully qualified and quoted table path, e.g. `"memory"."main"."users"`. + table_path: String, + /// Display title shown in the data viewer tab. + title: String, + /// Cached column schemas. + columns: Vec, + /// Cached total row count. 
+ num_rows: usize, +} + +impl DataExplorerState { + /// Open a data explorer for a table at the given connection path. + /// + /// Runs `SELECT COUNT(*)` and a column metadata query to cache schema + /// information. Does **not** load the full table into memory. + pub fn open(reader: &dyn Reader, path: &[String]) -> Result { + if path.len() < 3 { + return Err(format!( + "Expected [catalog, schema, table] path, got {} elements", + path.len() + )); + } + + let catalog = &path[0]; + let schema = &path[1]; + let table = &path[2]; + + let table_path = format!( + "\"{}\".\"{}\".\"{}\"", + catalog.replace('"', "\"\""), + schema.replace('"', "\"\""), + table.replace('"', "\"\""), + ); + + // Get row count + let count_sql = format!("SELECT COUNT(*) AS \"n\" FROM {}", table_path); + let count_df = reader + .execute_sql(&count_sql) + .map_err(|e| format!("Failed to count rows: {}", e))?; + let num_rows = count_df + .column("n") + .ok() + .and_then(|col| col.get(0).ok()) + .and_then(|val| { + // Polars AnyValue — try common integer representations + let s = format!("{}", val); + s.parse::().ok() + }) + .unwrap_or(0); + + // Get column metadata from information_schema + let columns_sql = reader.dialect().sql_list_columns(catalog, schema, table); + let columns_df = reader + .execute_sql(&columns_sql) + .map_err(|e| format!("Failed to list columns: {}", e))?; + + let name_col = columns_df + .column("column_name") + .map_err(|e| format!("Missing column_name: {}", e))?; + let type_col = columns_df + .column("data_type") + .map_err(|e| format!("Missing data_type: {}", e))?; + + let mut columns = Vec::new(); + for i in 0..columns_df.height() { + if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) { + let name = name_val.to_string().trim_matches('"').to_string(); + let raw_type = type_val.to_string().trim_matches('"').to_string(); + let type_display = sql_type_to_display(&raw_type).to_string(); + let type_name = clean_type_name(&raw_type); + 
columns.push(ColumnInfo { + name, + type_name, + type_display, + }); + } + } + + Ok(Self { + table_path, + title: table.clone(), + columns, + num_rows, + }) + } + + /// Dispatch a JSON-RPC method call. + /// + /// Returns the RPC result and an optional async event to send on iopub + /// (used by `get_column_profiles` to deliver results asynchronously). + pub fn handle_rpc(&self, method: &str, params: &Value, reader: &dyn Reader) -> RpcResponse { + match method { + "get_state" => RpcResponse::reply(self.get_state()), + "get_schema" => RpcResponse::reply(self.get_schema(params)), + "get_data_values" => RpcResponse::reply(self.get_data_values(params, reader)), + "get_column_profiles" => self.get_column_profiles(params, reader), + "set_row_filters" => { + // Stub: accept but ignore filters, return current shape + RpcResponse::reply(json!({ + "selected_num_rows": self.num_rows, + "had_errors": false + })) + } + "set_sort_columns" | "set_column_filters" | "search_schema" => { + RpcResponse::reply(json!(null)) + } + _ => { + tracing::warn!("Unhandled data explorer method: {}", method); + RpcResponse::reply(json!(null)) + } + } + } + + fn get_state(&self) -> Value { + let num_columns = self.columns.len(); + json!({ + "display_name": self.title, + "table_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "table_unfiltered_shape": { + "num_rows": self.num_rows, + "num_columns": num_columns + }, + "has_row_labels": false, + "column_filters": [], + "row_filters": [], + "sort_keys": [], + "supported_features": { + "search_schema": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_column_filters": { + "support_status": "unsupported", + "supported_types": [] + }, + "set_row_filters": { + "support_status": "unsupported", + "supports_conditions": "unsupported", + "supported_types": [] + }, + "get_column_profiles": { + "support_status": "supported", + "supported_types": [ + {"profile_type": "null_count", "support_status": "supported"}, 
+ {"profile_type": "summary_stats", "support_status": "supported"}, + {"profile_type": "small_histogram", "support_status": "supported"}, + {"profile_type": "small_frequency_table", "support_status": "supported"} + ] + }, + "set_sort_columns": { + "support_status": "unsupported" + }, + "export_data_selection": { + "support_status": "unsupported", + "supported_formats": [] + }, + "convert_to_code": { + "support_status": "unsupported" + } + } + }) + } + + fn get_schema(&self, params: &Value) -> Value { + let indices: Vec = params + .get("column_indices") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_u64().map(|n| n as usize)) + .collect() + }) + .unwrap_or_default(); + + let columns: Vec = indices + .iter() + .filter_map(|&idx| { + self.columns.get(idx).map(|col| { + json!({ + "column_name": col.name, + "column_index": idx, + "type_name": col.type_name, + "type_display": col.type_display + }) + }) + }) + .collect(); + + json!({ "columns": columns }) + } + + fn get_data_values(&self, params: &Value, reader: &dyn Reader) -> Value { + let selections = match params.get("columns").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return json!({ "columns": [] }), + }; + + // Determine the row range from the first selection's spec + let (first_index, last_index) = selections + .first() + .and_then(|sel| sel.get("spec")) + .map(|spec| { + let first = spec + .get("first_index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let last = spec.get("last_index").and_then(|v| v.as_u64()).unwrap_or(0) as usize; + (first, last) + }) + .unwrap_or((0, 0)); + + let limit = last_index.saturating_sub(first_index) + 1; + + // Collect requested column indices + let col_indices: Vec = selections + .iter() + .filter_map(|sel| { + sel.get("column_index") + .and_then(|v| v.as_u64()) + .map(|n| n as usize) + }) + .collect(); + + // Build column list for SELECT + let col_names: Vec = col_indices + .iter() + .filter_map(|&idx| { + self.columns + 
.get(idx) + .map(|col| format!("\"{}\"", col.name.replace('"', "\"\""))) + }) + .collect(); + + if col_names.is_empty() { + return json!({ "columns": [] }); + } + + let sql = format!( + "SELECT {} FROM {} LIMIT {} OFFSET {}", + col_names.join(", "), + self.table_path, + limit, + first_index, + ); + + let df = match reader.execute_sql(&sql) { + Ok(df) => df, + Err(e) => { + tracing::error!("get_data_values query failed: {}", e); + let empty: Vec> = col_indices.iter().map(|_| vec![]).collect(); + return json!({ "columns": empty }); + } + }; + + // Format each column's values as strings. + // Positron's ColumnValue is `number | string`: numbers are special + // value codes (0 = NULL, 1 = NA, 2 = NaN), strings are formatted data. + const SPECIAL_VALUE_NULL: i64 = 0; + + let columns: Vec> = (0..df.width()) + .map(|col_idx| { + let col = df.get_columns()[col_idx].clone(); + (0..df.height()) + .map(|row_idx| { + match col.get(row_idx) { + Ok(val) => { + if val.is_null() { + json!(SPECIAL_VALUE_NULL) + } else { + let s = format!("{}", val); + // Strip surrounding quotes from string values + let s = s.trim_matches('"'); + Value::String(s.to_string()) + } + } + Err(_) => json!(SPECIAL_VALUE_NULL), + } + }) + .collect() + }) + .collect(); + + json!({ "columns": columns }) + } + + /// Handle `get_column_profiles` — returns `{}` as the RPC result and sends + /// profile data back as an async `return_column_profiles` event. 
+ fn get_column_profiles(&self, params: &Value, reader: &dyn Reader) -> RpcResponse { + let callback_id = params + .get("callback_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let requests = match params.get("profiles").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => { + return RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": [] + }), + }), + }; + } + }; + + let mut profiles = Vec::new(); + for req in requests { + let col_idx = req + .get("column_index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + + let specs = req + .get("profiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let profile = self.compute_column_profile(col_idx, &specs, reader); + profiles.push(profile); + } + + RpcResponse { + result: json!({}), + event: Some(RpcEvent { + method: "return_column_profiles".into(), + params: json!({ + "callback_id": callback_id, + "profiles": profiles + }), + }), + } + } + + /// Compute profile results for a single column. 
+ fn compute_column_profile( + &self, + col_idx: usize, + specs: &[Value], + reader: &dyn Reader, + ) -> Value { + let col = match self.columns.get(col_idx) { + Some(c) => c, + None => return json!({}), + }; + + let mut wants_null_count = false; + let mut wants_summary = false; + let mut histogram_params: Option<&Value> = None; + let mut freq_table_params: Option<&Value> = None; + for spec in specs { + match spec + .get("profile_type") + .and_then(|v| v.as_str()) + .unwrap_or("") + { + "null_count" => wants_null_count = true, + "summary_stats" => wants_summary = true, + "small_histogram" => histogram_params = spec.get("params"), + "small_frequency_table" => freq_table_params = spec.get("params"), + _ => {} + } + } + + let dialect = reader.dialect(); + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + let display = col.type_display.as_str(); + + // Build a single SQL query that computes all needed aggregates. + // All expressions use ANSI SQL or existing dialect methods. 
+ let mut select_parts = Vec::new(); + if wants_null_count { + select_parts.push(format!( + "SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) AS \"null_count\"", + quoted_col + )); + } + if wants_summary { + match display { + "integer" | "floating" => { + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); + select_parts.push(format!( + "AVG(CAST({} AS {})) AS \"mean_val\"", + quoted_col, float_type + )); + // Stddev: fetch raw aggregates, compute in Rust + select_parts.push(format!( + "SUM(CAST({c} AS {t}) * CAST({c} AS {t})) AS \"sum_sq\"", + c = quoted_col, + t = float_type + )); + select_parts.push(format!( + "SUM(CAST({} AS {})) AS \"sum_val\"", + quoted_col, float_type + )); + select_parts.push(format!("COUNT({}) AS \"cnt\"", quoted_col)); + } + "boolean" => { + let true_lit = dialect.sql_boolean_literal(true); + let false_lit = dialect.sql_boolean_literal(false); + select_parts.push(format!( + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"true_count\"", + quoted_col, true_lit + )); + select_parts.push(format!( + "SUM(CASE WHEN {} = {} THEN 1 ELSE 0 END) AS \"false_count\"", + quoted_col, false_lit + )); + } + "string" => { + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); + select_parts.push(format!( + "SUM(CASE WHEN {} = '' THEN 1 ELSE 0 END) AS \"num_empty\"", + quoted_col + )); + } + "date" | "datetime" => { + select_parts.push(format!("MIN({}) AS \"min_val\"", quoted_col)); + select_parts.push(format!("MAX({}) AS \"max_val\"", quoted_col)); + select_parts.push(format!("COUNT(DISTINCT {}) AS \"num_unique\"", quoted_col)); + } + _ => {} + } + } + + if select_parts.is_empty() { + return json!({}); + } + + let sql = format!( + "SELECT {} FROM {}", + select_parts.join(", "), + self.table_path + ); + + let df = match reader.execute_sql(&sql) { + Ok(df) => df, + Err(e) => { + 
tracing::error!("Column profile query failed: {}", e); + return json!({}); + } + }; + + let get_str = |name: &str| -> Option { + df.column(name) + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + }; + + let get_i64 = + |name: &str| -> Option { get_str(name).and_then(|s| s.parse::().ok()) }; + + let get_f64 = + |name: &str| -> Option { get_str(name).and_then(|s| s.parse::().ok()) }; + + let mut result = json!({}); + + if wants_null_count { + if let Some(n) = get_i64("null_count") { + result["null_count"] = json!(n); + } + } + + if wants_summary { + let stats = match display { + "integer" | "floating" => { + let mut number_stats = json!({}); + if let Some(v) = get_str("min_val") { + number_stats["min_value"] = json!(v); + } + if let Some(v) = get_str("max_val") { + number_stats["max_value"] = json!(v); + } + if let Some(v) = get_str("mean_val") { + number_stats["mean"] = json!(v); + } + // Compute sample stddev from raw aggregates + if let (Some(sum_sq), Some(sum_val), Some(cnt)) = + (get_f64("sum_sq"), get_f64("sum_val"), get_i64("cnt")) + { + if cnt > 1 { + let variance = + (sum_sq - sum_val * sum_val / cnt as f64) / (cnt - 1) as f64; + let stdev = variance.max(0.0).sqrt(); + number_stats["stdev"] = json!(format!("{}", stdev)); + } + } + // Median via dialect's sql_percentile (uses QUANTILE_CONT on + // DuckDB, NTILE fallback on other backends) + let col_name = col.name.replace('"', "\"\""); + let from_query = format!("SELECT * FROM {}", self.table_path); + let median_expr = dialect.sql_percentile(&col_name, 0.5, &from_query, &[]); + let median_sql = format!("SELECT {} AS \"median_val\"", median_expr); + if let Ok(median_df) = reader.execute_sql(&median_sql) { + if let Some(v) = median_df + .column("median_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) 
+ } + }) + { + number_stats["median"] = json!(v); + } + } + json!({ + "type_display": display, + "number_stats": number_stats + }) + } + "boolean" => { + json!({ + "type_display": display, + "boolean_stats": { + "true_count": get_i64("true_count").unwrap_or(0), + "false_count": get_i64("false_count").unwrap_or(0) + } + }) + } + "string" => { + json!({ + "type_display": display, + "string_stats": { + "num_unique": get_i64("num_unique").unwrap_or(0), + "num_empty": get_i64("num_empty").unwrap_or(0) + } + }) + } + "date" => { + let mut date_stats = json!({}); + if let Some(v) = get_str("min_val") { + date_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + date_stats["max_date"] = json!(v); + } + if let Some(n) = get_i64("num_unique") { + date_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "date_stats": date_stats + }) + } + "datetime" => { + let mut datetime_stats = json!({}); + if let Some(v) = get_str("min_val") { + datetime_stats["min_date"] = json!(v); + } + if let Some(v) = get_str("max_val") { + datetime_stats["max_date"] = json!(v); + } + if let Some(n) = get_i64("num_unique") { + datetime_stats["num_unique"] = json!(n); + } + json!({ + "type_display": display, + "datetime_stats": datetime_stats + }) + } + _ => json!({"type_display": display}), + }; + result["summary_stats"] = stats; + } + + // Compute histogram if requested (only for numeric types) + if let Some(params) = histogram_params { + if matches!(display, "integer" | "floating") { + if let Some(hist) = self.compute_histogram(col, params, reader) { + result["small_histogram"] = hist; + } + } + } + + // Compute frequency table if requested (for string/boolean types) + if let Some(params) = freq_table_params { + if matches!(display, "string" | "boolean") { + if let Some(ft) = self.compute_frequency_table(col, params, reader) { + result["small_frequency_table"] = ft; + } + } + } + + result + } + + /// Compute a histogram for a numeric column. 
+ fn compute_histogram( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let max_bins = params + .get("num_bins") + .and_then(|v| v.as_u64()) + .unwrap_or(20) as usize; + + if max_bins == 0 { + return None; + } + + let dialect = reader.dialect(); + let float_type = dialect.number_type_name().unwrap_or("DOUBLE PRECISION"); + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + let is_integer = col.type_display == "integer"; + + // Get min, max, count in one query + let bounds_sql = format!( + "SELECT \ + MIN(CAST({c} AS {t})) AS \"min_val\", \ + MAX(CAST({c} AS {t})) AS \"max_val\", \ + COUNT({c}) AS \"cnt\" \ + FROM {table} WHERE {c} IS NOT NULL", + c = quoted_col, + t = float_type, + table = self.table_path, + ); + + let bounds_df = reader.execute_sql(&bounds_sql).ok()?; + let get_f64 = |name: &str| -> Option { + bounds_df + .column(name) + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + format!("{}", v).trim_matches('"').parse::().ok() + } + }) + }; + + let min_val = get_f64("min_val")?; + let max_val = get_f64("max_val")?; + let count = get_f64("cnt").unwrap_or(0.0) as usize; + + // Handle edge case: all values identical + if (max_val - min_val).abs() < f64::EPSILON { + return Some(json!({ + "bin_edges": [format!("{}", min_val), format!("{}", max_val)], + "bin_counts": [count as i64], + "quantiles": [] + })); + } + + // Determine actual bin count using Sturges' formula, capped at max_bins. + // For integers, also cap at (max - min + 1) to avoid sub-unit bins. + let mut num_bins = if count > 1 { + ((count as f64).log2().ceil() as usize + 1).max(1) + } else { + 1 + }; + if is_integer { + let int_range = (max_val - min_val) as usize + 1; + num_bins = num_bins.min(int_range); + } + num_bins = num_bins.min(max_bins).max(1); + + let bin_width = (max_val - min_val) / num_bins as f64; + + // Bin the data using FLOOR. 
Clamp the last bin to num_bins-1 so + // max value doesn't create an extra bin. + let hist_sql = format!( + "SELECT \ + CASE \ + WHEN \"bin\" >= {num_bins} THEN {last_bin} \ + ELSE \"bin\" \ + END AS \"clamped_bin\", \ + COUNT(*) AS \"cnt\" \ + FROM ( \ + SELECT FLOOR((CAST({c} AS {t}) - {min}) / {width}) AS \"bin\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + ) AS \"__bins__\" \ + GROUP BY \"clamped_bin\" \ + ORDER BY \"clamped_bin\"", + c = quoted_col, + t = float_type, + table = self.table_path, + min = min_val, + width = bin_width, + num_bins = num_bins, + last_bin = num_bins - 1, + ); + + let hist_df = reader.execute_sql(&hist_sql).ok()?; + + // Build bin_edges: num_bins + 1 edges + let bin_edges: Vec = (0..=num_bins) + .map(|i| format!("{}", min_val + i as f64 * bin_width)) + .collect(); + + // Build bin_counts: fill from query results (sparse bins get 0) + let mut bin_counts = vec![0i64; num_bins]; + let bin_col = hist_df.column("clamped_bin").ok()?; + let cnt_col = hist_df.column("cnt").ok()?; + for i in 0..hist_df.height() { + if let (Ok(bin_val), Ok(cnt_val)) = (bin_col.get(i), cnt_col.get(i)) { + let bin_str = format!("{}", bin_val); + // Parse bin index — may be float (e.g., "3.0") on some backends + if let Ok(bin_idx) = bin_str.parse::() { + let idx = bin_idx as usize; + if idx < num_bins { + let count_str = format!("{}", cnt_val); + bin_counts[idx] = count_str.parse::().unwrap_or(0); + } + } + } + } + + // Compute requested quantiles + let quantiles_param = params + .get("quantiles") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + let mut quantile_results = Vec::new(); + let from_query = format!("SELECT * FROM {}", self.table_path); + let col_name = col.name.replace('"', "\"\""); + for q in &quantiles_param { + if let Some(q_val) = q.as_f64() { + let expr = dialect.sql_percentile(&col_name, q_val, &from_query, &[]); + let q_sql = format!("SELECT {} AS \"q_val\"", expr); + if let Ok(q_df) = reader.execute_sql(&q_sql) { + if let 
Some(v) = q_df + .column("q_val") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| { + if v.is_null() { + None + } else { + Some(format!("{}", v).trim_matches('"').to_string()) + } + }) + { + quantile_results.push(json!({"q": q_val, "value": v})); + } + } + } + } + + Some(json!({ + "bin_edges": bin_edges, + "bin_counts": bin_counts, + "quantiles": quantile_results + })) + } + + /// Compute a frequency table for a string or boolean column. + fn compute_frequency_table( + &self, + col: &ColumnInfo, + params: &Value, + reader: &dyn Reader, + ) -> Option { + let limit = params.get("limit").and_then(|v| v.as_u64()).unwrap_or(8) as usize; + + let quoted_col = format!("\"{}\"", col.name.replace('"', "\"\"")); + + let sql = format!( + "SELECT {c} AS \"value\", COUNT(*) AS \"count\" \ + FROM {table} \ + WHERE {c} IS NOT NULL \ + GROUP BY {c} \ + ORDER BY COUNT(*) DESC \ + LIMIT {limit}", + c = quoted_col, + table = self.table_path, + limit = limit, + ); + + let df = reader.execute_sql(&sql).ok()?; + + let val_col = df.column("value").ok()?; + let cnt_col = df.column("count").ok()?; + + let mut values = Vec::new(); + let mut counts = Vec::new(); + let mut top_total: i64 = 0; + + for i in 0..df.height() { + if let (Ok(v), Ok(c)) = (val_col.get(i), cnt_col.get(i)) { + let val_str = format!("{}", v).trim_matches('"').to_string(); + let count: i64 = format!("{}", c).parse().unwrap_or(0); + values.push(Value::String(val_str)); + counts.push(count); + top_total += count; + } + } + + // Compute other_count: total non-null rows minus the top-K sum + let count_sql = format!( + "SELECT COUNT({c}) AS \"total\" FROM {table}", + c = quoted_col, + table = self.table_path, + ); + let other_count = reader + .execute_sql(&count_sql) + .ok() + .and_then(|df| { + df.column("total") + .ok() + .and_then(|c| c.get(0).ok()) + .and_then(|v| format!("{}", v).parse::().ok()) + }) + .map(|total| total - top_total) + .unwrap_or(0); + + Some(json!({ + "values": values, + "counts": counts, + 
"other_count": other_count + })) + } +} + +/// Map a SQL type name (from information_schema or SHOW COLUMNS) to a Positron display type. +/// +/// Handles both simple type names (e.g. "INTEGER", "VARCHAR") and Snowflake's +/// JSON format (e.g. `{"type":"FIXED","precision":38,"scale":0,...}`). +fn sql_type_to_display(type_name: &str) -> &'static str { + // Handle Snowflake JSON type format + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + "floating" + } else { + "integer" + } + } + "REAL" | "FLOAT" => "floating", + "TEXT" => "string", + "BOOLEAN" => "boolean", + "DATE" => "date", + "TIMESTAMP_NTZ" | "TIMESTAMP_LTZ" | "TIMESTAMP_TZ" => "datetime", + "TIME" => "time", + "BINARY" => "string", + "VARIANT" | "OBJECT" | "ARRAY" => "string", + _ => "unknown", + }; + } + } + } + + // Simple type names (DuckDB, PostgreSQL, SQLite, etc.) + let upper = type_name.to_uppercase(); + let upper = upper.as_str(); + + if upper.contains("INT") { + return "integer"; + } + if upper.contains("FLOAT") + || upper.contains("DOUBLE") + || upper.contains("REAL") + || upper.contains("NUMERIC") + || upper.contains("DECIMAL") + { + return "floating"; + } + if upper.contains("BOOL") { + return "boolean"; + } + if upper.contains("TIMESTAMP") || upper.contains("DATETIME") { + return "datetime"; + } + if upper.contains("DATE") { + return "date"; + } + if upper.contains("TIME") { + return "time"; + } + if upper.contains("CHAR") + || upper.contains("TEXT") + || upper.contains("STRING") + || upper.contains("VARCHAR") + || upper.contains("CLOB") + { + return "string"; + } + if upper.contains("BLOB") || upper.contains("BINARY") || upper.contains("BYTE") { + return "string"; + } + + "unknown" +} + +/// Clean up a raw type name for display in the schema response. 
+/// +/// For Snowflake JSON types, extracts the `type` field (e.g. "NUMBER", "TEXT"). +/// For simple type names, returns as-is. +fn clean_type_name(type_name: &str) -> String { + if type_name.starts_with('{') { + if let Ok(obj) = serde_json::from_str::(type_name) { + if let Some(t) = obj.get("type").and_then(|v| v.as_str()) { + return match t { + "FIXED" => { + let scale = obj.get("scale").and_then(|v| v.as_i64()).unwrap_or(0); + if scale > 0 { + format!( + "NUMBER({},{})", + obj.get("precision").and_then(|v| v.as_i64()).unwrap_or(38), + scale + ) + } else { + "NUMBER".to_string() + } + } + other => other.to_string(), + }; + } + } + } + type_name.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sql_type_to_display() { + assert_eq!(sql_type_to_display("INTEGER"), "integer"); + assert_eq!(sql_type_to_display("BIGINT"), "integer"); + assert_eq!(sql_type_to_display("SMALLINT"), "integer"); + assert_eq!(sql_type_to_display("TINYINT"), "integer"); + assert_eq!(sql_type_to_display("INT"), "integer"); + assert_eq!(sql_type_to_display("DOUBLE"), "floating"); + assert_eq!(sql_type_to_display("FLOAT"), "floating"); + assert_eq!(sql_type_to_display("REAL"), "floating"); + assert_eq!(sql_type_to_display("NUMERIC(10,2)"), "floating"); + assert_eq!(sql_type_to_display("DECIMAL(10,2)"), "floating"); + assert_eq!(sql_type_to_display("BOOLEAN"), "boolean"); + assert_eq!(sql_type_to_display("BOOL"), "boolean"); + assert_eq!(sql_type_to_display("VARCHAR"), "string"); + assert_eq!(sql_type_to_display("TEXT"), "string"); + assert_eq!(sql_type_to_display("DATE"), "date"); + assert_eq!(sql_type_to_display("TIMESTAMP"), "datetime"); + assert_eq!(sql_type_to_display("TIMESTAMP WITH TIME ZONE"), "datetime"); + assert_eq!(sql_type_to_display("TIME"), "time"); + assert_eq!(sql_type_to_display("BLOB"), "string"); + assert_eq!(sql_type_to_display("UNKNOWN_TYPE"), "unknown"); + } +} diff --git a/ggsql-jupyter/src/display.rs b/ggsql-jupyter/src/display.rs index 
1e1dbe67..78823779 100644 --- a/ggsql-jupyter/src/display.rs +++ b/ggsql-jupyter/src/display.rs @@ -35,9 +35,24 @@ pub fn format_display_data(result: ExecutionResult) -> Option { Some(format_dataframe(df)) } } + ExecutionResult::ConnectionChanged { display_name, .. } => { + Some(format_connection_changed(&display_name)) + } } } +/// Format a connection-changed message +fn format_connection_changed(display_name: &str) -> Value { + let text = format!("Connected to {}", display_name); + json!({ + "data": { + "text/plain": text + }, + "metadata": {}, + "transient": {} + }) +} + /// Format Vega-Lite visualization as display_data fn format_vegalite(spec: String) -> Value { let spec_value: Value = serde_json::from_str(&spec).unwrap_or_else(|e| { diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs index 1548f5e3..434345c3 100644 --- a/ggsql-jupyter/src/executor.rs +++ b/ggsql-jupyter/src/executor.rs @@ -2,10 +2,11 @@ //! //! This module handles the execution of ggsql queries using the existing //! ggsql library components (parser, DuckDB reader, Vega-Lite writer). +//! It supports dynamic reader switching via `-- @connect:` meta-commands. use anyhow::Result; use ggsql::{ - reader::{DuckDBReader, Reader}, + reader::{connection::parse_connection_string, DuckDBReader, Reader}, validate::validate, writer::{VegaLiteWriter, Writer}, }; @@ -20,39 +21,196 @@ pub enum ExecutionResult { Visualization { spec: String, // Vega-Lite JSON }, + /// Connection changed via meta-command + ConnectionChanged { uri: String, display_name: String }, } -/// Query executor maintaining persistent DuckDB connection +/// Create a reader from a connection URI string. 
+/// +/// Supported schemes: +/// - `duckdb://memory` or `duckdb://` (always available) +/// - `sqlite://` (requires `sqlite` feature) +/// - `odbc://...` (requires `odbc` feature) +pub fn create_reader(uri: &str) -> Result> { + use ggsql::reader::connection::ConnectionInfo; + + let info = parse_connection_string(uri)?; + match info { + ConnectionInfo::DuckDBMemory => { + let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + Ok(Box::new(reader)) + } + ConnectionInfo::DuckDBFile(path) => { + let reader = DuckDBReader::from_connection_string(&format!("duckdb://{}", path))?; + Ok(Box::new(reader)) + } + #[cfg(feature = "odbc")] + ConnectionInfo::ODBC(conn_str) => { + let reader = + ggsql::reader::OdbcReader::from_connection_string(&format!("odbc://{}", conn_str))?; + Ok(Box::new(reader)) + } + #[cfg(feature = "sqlite")] + ConnectionInfo::SQLite(path) => { + let reader = + ggsql::reader::SqliteReader::from_connection_string(&format!("sqlite://{}", path))?; + Ok(Box::new(reader)) + } + _ => anyhow::bail!("Unsupported reader type for connection string: {}", uri), + } +} + +/// Generate a human-readable display name for a connection URI. 
+pub fn display_name_for_uri(uri: &str) -> String { + if uri == "duckdb://memory" { + return "DuckDB (memory)".to_string(); + } + if let Some(path) = uri.strip_prefix("duckdb://") { + return format!("DuckDB ({})", path); + } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "SQLite (memory)".to_string(); + } + return format!("SQLite ({})", path); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + // Try to extract driver name from ODBC string + if let Some(driver_start) = odbc.to_lowercase().find("driver=") { + let rest = &odbc[driver_start + 7..]; + let driver = rest + .split(';') + .next() + .unwrap_or("ODBC") + .trim_matches(|c| c == '{' || c == '}'); + return format!("{} (ODBC)", driver); + } + return "ODBC".to_string(); + } + uri.to_string() +} + +/// Detect the database type name from a connection URI (e.g. "DuckDB", "Snowflake"). +pub fn type_name_for_uri(uri: &str) -> String { + if uri.starts_with("duckdb://") { + return "DuckDB".to_string(); + } + if uri.starts_with("sqlite://") { + return "SQLite".to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + if odbc.to_lowercase().contains("driver=snowflake") { + return "Snowflake".to_string(); + } + if odbc.to_lowercase().contains("driver={postgresql}") + || odbc.to_lowercase().contains("driver=postgresql") + { + return "PostgreSQL".to_string(); + } + return "ODBC".to_string(); + } + "Unknown".to_string() +} + +/// Extract the host portion from a connection URI. 
+pub fn host_for_uri(uri: &str) -> String { + if uri == "duckdb://memory" { + return "memory".to_string(); + } + if let Some(path) = uri.strip_prefix("duckdb://") { + return path.to_string(); + } + if let Some(path) = uri.strip_prefix("sqlite://") { + if path.is_empty() { + return "memory".to_string(); + } + return path.to_string(); + } + if let Some(odbc) = uri.strip_prefix("odbc://") { + // Try to extract server + if let Some(server_start) = odbc.to_lowercase().find("server=") { + let rest = &odbc[server_start + 7..]; + if let Some(host) = rest.split(';').next() { + return host.to_string(); + } + } + } + uri.to_string() +} + +/// The `-- @connect:` meta-command prefix. +const META_CONNECT_PREFIX: &str = "-- @connect:"; + +/// Parse a `-- @connect: ` meta-command, returning the URI if present. +pub fn parse_meta_command(code: &str) -> Option { + let trimmed = code.trim(); + trimmed + .strip_prefix(META_CONNECT_PREFIX) + .map(|rest| rest.trim().to_string()) +} + +/// Query executor maintaining persistent database connection pub struct QueryExecutor { - reader: DuckDBReader, + reader: Box, writer: VegaLiteWriter, + reader_uri: String, } impl QueryExecutor { - /// Create a new query executor with in-memory DuckDB database - pub fn new() -> Result { - tracing::info!("Initializing query executor with in-memory DuckDB"); - let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + /// Create a new query executor with a given connection URI + pub fn new_with_uri(uri: &str) -> Result { + tracing::info!("Initializing query executor with reader: {}", uri); + let reader = create_reader(uri)?; let writer = VegaLiteWriter::new(); - Ok(Self { reader, writer }) + Ok(Self { + reader, + writer, + reader_uri: uri.to_string(), + }) } - /// Execute a ggsql query - /// - /// This handles both pure SQL queries and queries with VISUALISE clauses. 
- /// - /// # Arguments - /// - /// * `code` - The ggsql query to execute - /// - /// # Returns + /// Create a new query executor with the default in-memory DuckDB database + #[cfg(test)] + pub fn new() -> Result { + Self::new_with_uri("duckdb://memory") + } + + /// Get the current reader URI + pub fn reader_uri(&self) -> &str { + &self.reader_uri + } + + /// Get a reference to the current reader (for schema introspection) + pub fn reader(&self) -> &dyn Reader { + &*self.reader + } + + /// Swap the reader to a new connection, returning the old URI + pub fn swap_reader(&mut self, uri: &str) -> Result { + let new_reader = create_reader(uri)?; + self.reader = new_reader; + let old_uri = std::mem::replace(&mut self.reader_uri, uri.to_string()); + Ok(old_uri) + } + + /// Execute a ggsql query or meta-command /// - /// An ExecutionResult containing either a DataFrame (for pure SQL) or - /// a Visualization (for queries with VISUALISE clause) - pub fn execute(&self, code: &str) -> Result { + /// This handles: + /// - `-- @connect: ` meta-commands for switching readers + /// - Pure SQL queries (no VISUALISE) + /// - ggsql queries with VISUALISE clauses + pub fn execute(&mut self, code: &str) -> Result { tracing::debug!("Executing query: {} chars", code.len()); + // Check for meta-commands first + if let Some(uri) = parse_meta_command(code) { + tracing::info!("Meta-command: switching reader to {}", uri); + self.swap_reader(&uri)?; + let display_name = display_name_for_uri(&uri); + return Ok(ExecutionResult::ConnectionChanged { uri, display_name }); + } + // 1. 
Validate to check if there's a visualization let validated = validate(code)?; @@ -93,7 +251,7 @@ mod tests { #[test] fn test_simple_visualization() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y VISUALISE x, y DRAW point"; let result = executor.execute(code).unwrap(); @@ -102,7 +260,7 @@ mod tests { #[test] fn test_pure_sql() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT 1 as x, 2 as y"; let result = executor.execute(code).unwrap(); @@ -111,10 +269,38 @@ mod tests { #[test] fn test_error_handling() { - let executor = QueryExecutor::new().unwrap(); + let mut executor = QueryExecutor::new().unwrap(); let code = "SELECT * FROM nonexistent_table"; let result = executor.execute(code); assert!(result.is_err()); } + + #[test] + fn test_parse_meta_command() { + assert_eq!( + parse_meta_command("-- @connect: duckdb://memory"), + Some("duckdb://memory".to_string()) + ); + assert_eq!( + parse_meta_command(" -- @connect: duckdb://my.db "), + Some("duckdb://my.db".to_string()) + ); + assert_eq!(parse_meta_command("SELECT 1"), None); + } + + #[test] + fn test_meta_command_switches_reader() { + let mut executor = QueryExecutor::new().unwrap(); + assert_eq!(executor.reader_uri(), "duckdb://memory"); + + let result = executor.execute("-- @connect: duckdb://memory").unwrap(); + assert!(matches!(result, ExecutionResult::ConnectionChanged { .. })); + } + + #[test] + fn test_display_name_for_uri() { + assert_eq!(display_name_for_uri("duckdb://memory"), "DuckDB (memory)"); + assert_eq!(display_name_for_uri("duckdb://my.db"), "DuckDB (my.db)"); + } } diff --git a/ggsql-jupyter/src/kernel.rs b/ggsql-jupyter/src/kernel.rs index 14c07340..2229048e 100644 --- a/ggsql-jupyter/src/kernel.rs +++ b/ggsql-jupyter/src/kernel.rs @@ -3,13 +3,16 @@ //! This module implements the Jupyter messaging protocol over ZeroMQ sockets, //! 
handling kernel_info, execute, and shutdown requests. +use crate::connection; +use crate::data_explorer::{DataExplorerState, RpcResponse}; use crate::display::format_display_data; -use crate::executor::QueryExecutor; +use crate::executor::{self, ExecutionResult, QueryExecutor}; use crate::message::{ConnectionInfo, JupyterMessage, MessageHeader}; use anyhow::Result; use hmac::{Hmac, Mac}; use serde_json::{json, Value}; use sha2::Sha256; +use std::collections::HashMap; use zeromq::{PubSocket, RepSocket, RouterSocket, Socket, SocketRecv, SocketSend}; type HmacSha256 = Hmac; @@ -32,11 +35,14 @@ pub struct KernelServer { variables_comm_id: Option, ui_comm_id: Option, plot_comm_id: Option, + connection_comm_id: Option, + // Open data explorer comms (comm_id → state) + data_explorer_comms: HashMap, } impl KernelServer { /// Create a new kernel server from connection info - pub async fn new(connection: ConnectionInfo) -> Result { + pub async fn new(connection: ConnectionInfo, reader_uri: &str) -> Result { tracing::info!("Initializing kernel server"); // Initialize sockets @@ -68,8 +74,8 @@ impl KernelServer { tracing::info!("Binding heartbeat socket to {}", hb_addr); heartbeat.bind(&hb_addr).await?; - // Create executor - let executor = QueryExecutor::new()?; + // Create executor with the specified reader + let executor = QueryExecutor::new_with_uri(reader_uri)?; // Generate session ID let session = uuid::Uuid::new_v4().to_string(); @@ -92,12 +98,17 @@ impl KernelServer { variables_comm_id: None, ui_comm_id: None, plot_comm_id: None, + connection_comm_id: None, + data_explorer_comms: HashMap::new(), }; // Send initial "starting" status on IOPub // This is required by Jupyter protocol - exactly once at process startup kernel.send_status_initial("starting").await?; + // Open initial connection comm so the Connections pane shows the database + kernel.open_connection_comm(reader_uri).await?; + Ok(kernel) } @@ -310,10 +321,17 @@ impl KernelServer { match result { 
Ok(exec_result) => { + // If the connection changed, open a new connection comm + let is_connection_changed = + matches!(&exec_result, ExecutionResult::ConnectionChanged { .. }); + if let ExecutionResult::ConnectionChanged { ref uri, .. } = &exec_result { + self.open_connection_comm(uri).await?; + } + // Send execute_result (not display_data) // Per Jupyter spec: execute_result includes execution_count // Only send if there's something to display (DDL returns None) - if !silent { + if !silent && !is_connection_changed { if let Some(display_data) = format_display_data(exec_result) { // Build message content, including output_location if present let mut content = json!({ @@ -498,6 +516,7 @@ impl KernelServer { self.send_status("busy", parent).await?; // Check if it's a JSON-RPC request + #[allow(clippy::if_same_then_else)] if let Some(method) = data["method"].as_str() { let rpc_id = &data["id"]; @@ -588,11 +607,47 @@ impl KernelServer { } // Handle positron.ui requests else if Some(comm_id.to_string()) == self.ui_comm_id { - tracing::info!("Received UI request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; } // Handle positron.plot requests else if Some(comm_id.to_string()) == self.plot_comm_id { - tracing::info!("Received plot request: {} (ignoring)", method); + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": null + } + }), + parent, + identities, + ) + .await?; + } + // Handle positron.connection requests + else if Some(comm_id.to_string()) == self.connection_comm_id { + self.handle_connection_rpc(method, rpc_id, comm_id, parent, identities) + .await?; + } + // Handle positron.dataExplorer requests + else if self.data_explorer_comms.contains_key(comm_id) { + self.handle_data_explorer_rpc(method, rpc_id, comm_id, parent, 
identities) + .await?; } // Unknown comm else { @@ -634,6 +689,16 @@ impl KernelServer { comms[id] = json!({"target_name": "positron.plot"}); } } + if let Some(id) = &self.connection_comm_id { + if target_name.is_none() || target_name == Some("positron.connection") { + comms[id] = json!({"target_name": "positron.connection"}); + } + } + for id in self.data_explorer_comms.keys() { + if target_name.is_none() || target_name == Some("positron.dataExplorer") { + comms[id] = json!({"target_name": "positron.dataExplorer"}); + } + } tracing::info!( "Returning comms: {}", @@ -677,6 +742,11 @@ impl KernelServer { } else if Some(comm_id.to_string()) == self.plot_comm_id { tracing::info!("Closing positron.plot comm"); self.plot_comm_id = None; + } else if Some(comm_id.to_string()) == self.connection_comm_id { + tracing::info!("Closing positron.connection comm"); + self.connection_comm_id = None; + } else if self.data_explorer_comms.remove(comm_id).is_some() { + tracing::info!("Closing data explorer comm: {}", comm_id); } else { tracing::warn!("Close for unknown comm_id: {}", comm_id); } @@ -685,6 +755,248 @@ impl KernelServer { Ok(()) } + /// Open (or replace) a `positron.connection` comm for the current reader. + /// + /// The kernel initiates this comm (backend-initiated). If an existing + /// connection comm is open, it is closed first. 
+ async fn open_connection_comm(&mut self, uri: &str) -> Result<()> { + // Close existing connection comm if any + if let Some(old_id) = self.connection_comm_id.take() { + tracing::info!("Closing old connection comm: {}", old_id); + let close_msg = self.create_message("comm_close", json!({ "comm_id": old_id }), None); + let zmq_msg = self.serialize_message_with_topic(&close_msg, "comm_close")?; + self.iopub.send(zmq_msg).await?; + } + + let comm_id = uuid::Uuid::new_v4().to_string(); + let display_name = executor::display_name_for_uri(uri); + let type_name = executor::type_name_for_uri(uri); + let host = executor::host_for_uri(uri); + let meta_command = format!("-- @connect: {}", uri); + + tracing::info!( + "Opening positron.connection comm: {} ({})", + comm_id, + display_name + ); + + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": comm_id, + "target_name": "positron.connection", + "data": { + "name": display_name, + "language_id": "ggsql", + "host": host, + "type": type_name, + "code": meta_command + } + }), + None, + ); + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + self.connection_comm_id = Some(comm_id); + Ok(()) + } + + /// Handle JSON-RPC requests on the connection comm + async fn handle_connection_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Connection RPC: {}", method); + + let params = &parent.content["data"]["params"]; + + let result = match method { + "list_objects" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_objects(self.executor.reader(), &path) { + Ok(objects) => json!(objects), + Err(e) => { + tracing::error!("list_objects failed: {}", e); + json!([]) + } + } + } 
+ "list_fields" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + match connection::list_fields(self.executor.reader(), &path) { + Ok(fields) => json!(fields), + Err(e) => { + tracing::error!("list_fields failed: {}", e); + json!([]) + } + } + } + "contains_data" => { + let path: Vec = params["path"].as_array().cloned().unwrap_or_default(); + let has_data = connection::contains_data(&path); + json!(has_data) + } + "get_icon" => json!(""), + "preview_object" => { + let path: Vec = params["path"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| { + v.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) + .unwrap_or_default(); + + match DataExplorerState::open(self.executor.reader(), &path) { + Ok(state) => { + let de_comm_id = uuid::Uuid::new_v4().to_string(); + let title = path.last().cloned().unwrap_or_default(); + + // Send comm_open on iopub to open the data viewer + let msg = self.create_message( + "comm_open", + json!({ + "comm_id": de_comm_id, + "target_name": "positron.dataExplorer", + "data": { + "title": title + } + }), + Some(parent), + ); + let zmq_msg = self.serialize_message_with_topic(&msg, "comm_open")?; + self.iopub.send(zmq_msg).await?; + + tracing::info!("Opened data explorer comm: {} for {}", de_comm_id, title); + self.data_explorer_comms.insert(de_comm_id, state); + } + Err(e) => { + tracing::error!("preview_object failed: {}", e); + } + } + json!(null) + } + "get_metadata" => { + let uri = self.executor.reader_uri(); + json!({ + "name": executor::display_name_for_uri(uri), + "language_id": "ggsql", + "host": executor::host_for_uri(uri), + "type": executor::type_name_for_uri(uri), + "code": format!("-- @connect: {}", uri) + }) + } + _ => { + tracing::warn!("Unknown connection method: {}", method); + json!(null) + } + }; + + 
self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + Ok(()) + } + + /// Handle JSON-RPC requests on a data explorer comm + async fn handle_data_explorer_rpc( + &mut self, + method: &str, + rpc_id: &Value, + comm_id: &str, + parent: &JupyterMessage, + identities: &[Vec], + ) -> Result<()> { + tracing::info!("Data explorer RPC: {}", method); + + let params = &parent.content["data"]["params"]; + + let RpcResponse { result, event } = + if let Some(state) = self.data_explorer_comms.get(comm_id) { + state.handle_rpc(method, params, self.executor.reader()) + } else { + RpcResponse::reply(json!(null)) + }; + + // Send the RPC reply + self.send_shell_reply( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "id": rpc_id, + "result": result + } + }), + parent, + identities, + ) + .await?; + + // Send async event on iopub if present (e.g. return_column_profiles) + if let Some(evt) = event { + self.send_iopub( + "comm_msg", + json!({ + "comm_id": comm_id, + "data": { + "jsonrpc": "2.0", + "method": evt.method, + "params": evt.params + } + }), + parent, + ) + .await?; + } + + Ok(()) + } + /// Send a message on the IOPub channel async fn send_iopub( &mut self, diff --git a/ggsql-jupyter/src/lib.rs b/ggsql-jupyter/src/lib.rs index 40861748..5d2aaf55 100644 --- a/ggsql-jupyter/src/lib.rs +++ b/ggsql-jupyter/src/lib.rs @@ -2,6 +2,8 @@ //! //! This module exposes the internal components for testing. +pub mod connection; +pub mod data_explorer; pub mod display; pub mod executor; pub mod message; diff --git a/ggsql-jupyter/src/main.rs b/ggsql-jupyter/src/main.rs index 316ab8ba..896ce756 100644 --- a/ggsql-jupyter/src/main.rs +++ b/ggsql-jupyter/src/main.rs @@ -2,6 +2,8 @@ //! //! A Jupyter kernel for executing ggsql queries with rich Vega-Lite visualizations. 
+mod connection; +mod data_explorer; mod display; mod executor; mod kernel; @@ -22,6 +24,10 @@ struct Args { #[arg(short = 'f', long = "connection-file")] connection_file: Option, + /// Database connection URI (e.g. "duckdb://memory") + #[arg(long, default_value = "duckdb://memory")] + reader: String, + /// Install the kernel spec #[arg(long)] install: bool, @@ -69,7 +75,7 @@ async fn main() -> Result<()> { tracing::info!("Creating kernel server"); // Create and run kernel - let mut kernel = kernel::KernelServer::new(connection).await?; + let mut kernel = kernel::KernelServer::new(connection, &args.reader).await?; tracing::info!("Kernel ready, starting event loop"); diff --git a/ggsql-python/src/lib.rs b/ggsql-python/src/lib.rs index d477275e..6e5c543c 100644 --- a/ggsql-python/src/lib.rs +++ b/ggsql-python/src/lib.rs @@ -164,6 +164,10 @@ impl Reader for PyReaderBridge { }) } + fn execute(&self, query: &str) -> ggsql::Result { + ggsql::reader::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn ggsql::reader::SqlDialect { &ANSI_DIALECT } diff --git a/ggsql-vscode/package-lock.json b/ggsql-vscode/package-lock.json index 7e5b7c3f..d61cfc95 100644 --- a/ggsql-vscode/package-lock.json +++ b/ggsql-vscode/package-lock.json @@ -8,6 +8,9 @@ "name": "ggsql", "version": "0.1.9", "license": "MIT", + "dependencies": { + "toml": "^3.0.0" + }, "devDependencies": { "@posit-dev/positron": "^0.2.2", "@types/node": "^18.x", @@ -3477,6 +3480,12 @@ "url": "https://github.com/sponsors/SuperchupuDev" } }, + "node_modules/toml": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/toml/-/toml-3.0.0.tgz", + "integrity": "sha512-y/mWCZinnvxjTKYhJ+pYxwD0mRLVvOtdS2Awbgxln6iEnt4rk0yBxeSBHkGJcPucRiG0e55mwWp+g/05rsrd6w==", + "license": "MIT" + }, "node_modules/ts-api-utils": { "version": "2.5.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.5.0.tgz", diff --git a/ggsql-vscode/package.json b/ggsql-vscode/package.json index e9847c02..7032bf45 
100644 --- a/ggsql-vscode/package.json +++ b/ggsql-vscode/package.json @@ -77,6 +77,9 @@ "check-types": "tsc --noEmit", "lint": "eslint src --ext ts" }, + "dependencies": { + "toml": "^3.0.0" + }, "devDependencies": { "@posit-dev/positron": "^0.2.2", "@types/node": "^18.x", diff --git a/ggsql-vscode/src/connections.ts b/ggsql-vscode/src/connections.ts new file mode 100644 index 00000000..0df0641f --- /dev/null +++ b/ggsql-vscode/src/connections.ts @@ -0,0 +1,397 @@ +/* + * Connection Drivers for Positron's Connections pane + * + * Registers drivers that let users create database connections via the + * "New Connection" dialog. Each driver generates a `-- @connect:` meta-command + * that the ggsql-jupyter kernel interprets to switch readers. + */ + +import * as os from 'os'; +import * as path from 'path'; +import * as fs from 'fs'; +import * as toml from 'toml'; +import type * as positron from '@posit-dev/positron'; + +type PositronApi = positron.PositronApi; +type ConnectionsDriverMetadata = positron.ConnectionsDriverMetadata & { description?: string }; + +/** + * Create the set of ggsql connection drivers to register with Positron. + */ +export function createConnectionDrivers( + positronApi: PositronApi +): positron.ConnectionsDriver[] { + return [ + createDuckDBDriver(positronApi), + createSnowflakeDefaultDriver(positronApi), + createSnowflakePasswordDriver(positronApi), + createSnowflakeSSODriver(positronApi), + createSnowflakePATDriver(positronApi), + createOdbcDriver(positronApi), + ]; +} + +// ============================================================================ +// DuckDB +// ============================================================================ + +/** + * DuckDB connection driver. + * + * Inputs: optional database file path (empty = in-memory). 
+ */ +function createDuckDBDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-duckdb', + metadata: { + languageId: 'ggsql', + name: 'DuckDB', + inputs: [ + { + id: 'database', + label: 'Database', + type: 'string', + value: '', + }, + ], + }, + generateCode: (inputs) => { + const db = inputs.find((i) => i.id === 'database')?.value?.trim(); + if (!db) { + return '-- @connect: duckdb://memory'; + } + return `-- @connect: duckdb://${db}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} + +// ============================================================================ +// Snowflake — shared helpers +// ============================================================================ + +interface SnowflakeConnectionEntry { + name: string; + account?: string; +} + +/** + * Find the Snowflake connections.toml file, checking standard locations. + */ +function findSnowflakeConnectionsToml(): string | undefined { + const candidates: string[] = []; + + // 1. $SNOWFLAKE_HOME/connections.toml + const snowflakeHome = process.env.SNOWFLAKE_HOME; + if (snowflakeHome) { + candidates.push(path.join(snowflakeHome, 'connections.toml')); + } + + // 2. ~/.snowflake/connections.toml + const home = os.homedir(); + candidates.push(path.join(home, '.snowflake', 'connections.toml')); + + // 3. 
Platform-specific paths + if (process.platform === 'darwin') { + candidates.push( + path.join(home, 'Library', 'Application Support', 'snowflake', 'connections.toml') + ); + } else if (process.platform === 'linux') { + const xdgConfig = process.env.XDG_CONFIG_HOME || path.join(home, '.config'); + candidates.push(path.join(xdgConfig, 'snowflake', 'connections.toml')); + } else if (process.platform === 'win32') { + candidates.push( + path.join(home, 'AppData', 'Local', 'snowflake', 'connections.toml') + ); + } + + for (const candidate of candidates) { + if (fs.existsSync(candidate)) { + return candidate; + } + } + return undefined; +} + +/** + * Read Snowflake connection entries from connections.toml. + */ +function readSnowflakeConnections(): { + connections: SnowflakeConnectionEntry[]; + defaultConnection?: string; +} { + const tomlPath = findSnowflakeConnectionsToml(); + if (!tomlPath) { + return { connections: [] }; + } + + try { + const content = fs.readFileSync(tomlPath, 'utf-8'); + const parsed = toml.parse(content); + + const defaultConnection = + process.env.SNOWFLAKE_DEFAULT_CONNECTION_NAME || + parsed.default_connection_name || + undefined; + + const connections: SnowflakeConnectionEntry[] = Object.keys(parsed) + .filter( + (key) => + key !== 'default_connection_name' && + typeof parsed[key] === 'object' && + parsed[key] !== null + ) + .map((name) => ({ + name, + account: parsed[name].account as string | undefined, + })); + + return { connections, defaultConnection }; + } catch { + return { connections: [] }; + } +} + +/** + * Build an ODBC connection string for Snowflake with the given parts. 
+ */ +function buildSnowflakeOdbc(parts: Record<string, string | undefined>): string { + let connStr = `Driver=Snowflake;Server=${parts.account}.snowflakecomputing.com`; + if (parts.uid) { + connStr += `;UID=${parts.uid}`; + } + if (parts.pwd) { + connStr += `;PWD=${parts.pwd}`; + } + if (parts.authenticator) { + connStr += `;Authenticator=${parts.authenticator}`; + } + if (parts.token) { + connStr += `;Token=${parts.token}`; + } + if (parts.warehouse) { + connStr += `;Warehouse=${parts.warehouse}`; + } + if (parts.database) { + connStr += `;Database=${parts.database}`; + } + if (parts.schema) { + connStr += `;Schema=${parts.schema}`; + } + return `-- @connect: odbc://${connStr}`; +} + +function snowflakeConnect(positronApi: PositronApi) { + return async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }; +} + +// ============================================================================ +// Snowflake — Default Connection (connections.toml) +// ============================================================================ + +function createSnowflakeDefaultDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + const { connections, defaultConnection } = readSnowflakeConnections(); + + let inputs: positron.ConnectionsInput[]; + if (connections.length > 0) { + const defaultValue = + defaultConnection || + (connections.find((c) => c.name === 'default')?.name ?? connections[0].name); + + inputs = [ + { + id: 'connection_name', + label: 'Connection Name', + type: 'option', + options: connections.map((conn) => ({ + identifier: conn.name, + title: conn.account + ? 
`${conn.name} (${conn.account})` + : conn.name, + })), + value: defaultValue, + }, + ]; + } else { + inputs = [ + { + id: 'connection_name', + label: 'Connection Name', + type: 'string', + value: 'default', + }, + ]; + } + + return { + driverId: 'ggsql-snowflake-default', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Default Connection (connections.toml)', + inputs, + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const name = + inputs.find((i) => i.id === 'connection_name')?.value?.trim() || 'default'; + return `-- @connect: odbc://Driver=Snowflake;ConnectionName=${name}`; + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — Username/Password +// ============================================================================ + +function createSnowflakePasswordDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-password', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Username/Password', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string' }, + { id: 'password', label: 'Password', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user'), + pwd: get('password'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// 
============================================================================ +// Snowflake — External Browser (SSO) +// ============================================================================ + +function createSnowflakeSSODriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-sso', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'External Browser (SSO)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'user', label: 'User', type: 'string', value: '' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + uid: get('user') || undefined, + authenticator: 'externalbrowser', + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Snowflake — Programmatic Access Token (PAT) +// ============================================================================ + +function createSnowflakePATDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-snowflake-pat', + metadata: { + languageId: 'ggsql', + name: 'Snowflake', + description: 'Programmatic Access Token (PAT)', + inputs: [ + { id: 'account', label: 'Account', type: 'string' }, + { id: 'token', label: 'Token', type: 'string' }, + { id: 'warehouse', label: 'Warehouse', type: 'string' }, + { id: 'database', label: 'Database', type: 'string', value: '' }, + { id: 'schema', label: 'Schema', type: 'string', value: '' }, + ], + } as 
ConnectionsDriverMetadata, + generateCode: (inputs) => { + const get = (id: string) => + inputs.find((i) => i.id === id)?.value?.trim() || ''; + return buildSnowflakeOdbc({ + account: get('account'), + authenticator: 'programmatic_access_token', + token: get('token'), + warehouse: get('warehouse'), + database: get('database') || undefined, + schema: get('schema') || undefined, + }); + }, + connect: snowflakeConnect(positronApi), + }; +} + +// ============================================================================ +// Generic ODBC +// ============================================================================ + +/** + * Generic ODBC connection driver. + * + * Lets users paste a raw ODBC connection string. + */ +function createOdbcDriver( + positronApi: PositronApi +): positron.ConnectionsDriver { + return { + driverId: 'ggsql-odbc', + metadata: { + languageId: 'ggsql', + name: 'ODBC', + inputs: [ + { + id: 'connection_string', + label: 'Connection String', + type: 'string', + }, + ], + }, + generateCode: (inputs) => { + const connStr = + inputs.find((i) => i.id === 'connection_string')?.value ?? 
''; + return `-- @connect: odbc://${connStr}`; + }, + connect: async (code: string) => { + await positronApi.runtime.executeCode('ggsql', code, false); + }, + }; +} diff --git a/ggsql-vscode/src/extension.ts b/ggsql-vscode/src/extension.ts index 6d6d1ec5..54b76edf 100644 --- a/ggsql-vscode/src/extension.ts +++ b/ggsql-vscode/src/extension.ts @@ -8,6 +8,7 @@ import * as vscode from 'vscode'; import { tryAcquirePositronApi } from '@posit-dev/positron'; import { GgsqlRuntimeManager } from './manager'; +import { createConnectionDrivers } from './connections'; // Output channel for logging const outputChannel = vscode.window.createOutputChannel('ggsql'); @@ -42,6 +43,15 @@ export function activate(context: vscode.ExtensionContext): void { context.subscriptions.push(disposable); log('ggsql runtime manager registered successfully'); + + // Register connection drivers for the Connections pane + const drivers = createConnectionDrivers(positronApi); + for (const driver of drivers) { + const driverDisposable = positronApi.connections.registerConnectionDriver(driver); + context.subscriptions.push(driverDisposable); + } + + log(`Registered ${drivers.length} connection drivers`); } /** diff --git a/ggsql-vscode/src/manager.ts b/ggsql-vscode/src/manager.ts index da95c0c2..5cf159de 100644 --- a/ggsql-vscode/src/manager.ts +++ b/ggsql-vscode/src/manager.ts @@ -117,7 +117,7 @@ function generateMetadata( * * @param workspacePath - Optional workspace path to use as the kernel's working directory */ -function createKernelSpec(workspacePath?: string): JupyterKernelSpec { +function createKernelSpec(workspacePath?: string, readerUri?: string): JupyterKernelSpec { const kernelPath = getKernelPath(); return { @@ -131,11 +131,20 @@ function createKernelSpec(workspacePath?: string): JupyterKernelSpec { startKernel: async (session: JupyterSession, kernel: JupyterKernel) => { kernel.log(`Starting ggsql kernel with connection file: ${session.state.connectionFile}`); kernel.log(`Working 
directory: ${workspacePath ?? 'inherited from parent'}`); + if (readerUri) { + kernel.log(`Reader URI: ${readerUri}`); + } const connectionFile = session.state.connectionFile; + // Build arguments + const args = ['-f', connectionFile]; + if (readerUri) { + args.push('--reader', readerUri); + } + // Start the kernel process - const proc = cp.spawn(kernelPath, ['-f', connectionFile], { + const proc = cp.spawn(kernelPath, args, { stdio: ['ignore', 'pipe', 'pipe'], detached: false, cwd: workspacePath diff --git a/src/Cargo.toml b/src/Cargo.toml index 4f45fea2..d8462fa0 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -39,6 +39,8 @@ duckdb = { workspace = true, optional = true } arrow = { workspace = true, optional = true } postgres = { workspace = true, optional = true } rusqlite = { workspace = true, optional = true } +odbc-api = { workspace = true, optional = true } +toml_edit = { workspace = true, optional = true } # Writers plotters = { workspace = true, optional = true } @@ -73,21 +75,23 @@ pyo3 = { workspace = true, optional = true } [dev-dependencies] jsonschema = "0.44" proptest.workspace = true +tempfile = "3.8" ureq = "3" [features] -default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data"] +default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data", "odbc"] ipc = ["polars/ipc"] duckdb = ["dep:duckdb", "dep:arrow"] parquet = ["polars/parquet"] postgres = ["dep:postgres"] sqlite = ["dep:rusqlite"] +odbc = ["dep:odbc-api", "dep:toml_edit"] vegalite = [] ggplot2 = [] builtin-data = [] python = ["dep:pyo3"] rest-api = ["dep:axum", "dep:tokio", "dep:tower-http", "dep:tracing", "dep:tracing-subscriber", "duckdb", "vegalite"] -all-readers = ["duckdb", "postgres", "sqlite"] +all-readers = ["duckdb", "postgres", "sqlite", "odbc"] all-writers = ["vegalite", "ggplot2", "plotters"] # cargo-packager configuration for cross-platform installers diff --git a/src/execute/cte.rs b/src/execute/cte.rs index 5a6b665f..983b05e3 100644 --- 
a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -94,7 +94,7 @@ pub fn transform_cte_references(sql: &str, cte_names: &HashSet) -> Strin let mut result = sql.to_string(); for cte_name in cte_names { - let temp_table_name = naming::cte_table(cte_name); + let temp_table_name = format!("\"{}\"", naming::cte_table(cte_name)); // Replace table references: FROM cte_name, JOIN cte_name, cte_name.column // Use word boundary matching to avoid replacing substrings @@ -360,7 +360,7 @@ mod tests { ( "SELECT * FROM sales WHERE year = 2024", vec!["sales"], - vec!["FROM __ggsql_cte_sales_", "__ WHERE year = 2024"], + vec!["FROM \"__ggsql_cte_sales_", "__\" WHERE year = 2024"], None, ), // Multiple CTE references with qualified columns @@ -368,10 +368,10 @@ mod tests { "SELECT sales.date, targets.revenue FROM sales JOIN targets ON sales.id = targets.id", vec!["sales", "targets"], vec![ - "FROM __ggsql_cte_sales_", - "JOIN __ggsql_cte_targets_", - "__ggsql_cte_sales_", // qualified reference sales.date - "__ggsql_cte_targets_", // qualified reference targets.revenue + "FROM \"__ggsql_cte_sales_", + "JOIN \"__ggsql_cte_targets_", + "\"__ggsql_cte_sales_", // qualified reference sales.date + "\"__ggsql_cte_targets_", // qualified reference targets.revenue ], None, ), @@ -379,7 +379,7 @@ mod tests { ( "WHERE sales.date > '2024-01-01' AND sales.revenue > 100", vec!["sales"], - vec!["__ggsql_cte_sales_"], + vec!["\"__ggsql_cte_sales_"], None, ), // No matching CTE (unchanged) diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 409dffc3..6646f80f 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -54,7 +54,7 @@ pub fn layer_source_query( None => { // Layer uses global data debug_assert!(has_global, "Layer has no source and no global data"); - Ok(format!("SELECT * FROM {}", naming::global_table())) + Ok(format!("SELECT * FROM \"{}\"", naming::global_table())) } } } @@ -314,15 +314,15 @@ pub fn apply_pre_stat_transform( .filter(|col| seen.insert(&col.name)) 
.map(|col| { if let Some((_, sql)) = transform_exprs.iter().find(|(c, _)| c == &col.name) { - format!("{} AS \"{}\"", sql, col.name) + format!("{} AS {}", sql, naming::quote_ident(&col.name)) } else { - format!("\"{}\"", col.name) + naming::quote_ident(&col.name) } }) .collect(); format!( - "SELECT {} FROM ({}) AS __ggsql_pre__", + "SELECT {} FROM ({}) AS \"__ggsql_pre__\"", select_exprs.join(", "), query ) @@ -374,14 +374,14 @@ pub fn build_layer_base_query( // Build query with optional WHERE clause if let Some(ref f) = layer.filter { format!( - "SELECT {} FROM ({}) AS __ggsql_src__ WHERE {}", + "SELECT {} FROM ({}) AS \"__ggsql_src__\" WHERE {}", select_clause, source_query, f.as_str() ) } else { format!( - "SELECT {} FROM ({}) AS __ggsql_src__", + "SELECT {} FROM ({}) AS \"__ggsql_src__\"", select_clause, source_query ) } @@ -620,7 +620,7 @@ where transformed_query } else { format!( - "SELECT *, {} FROM ({}) AS __ggsql_stat__", + "SELECT *, {} FROM ({}) AS \"__ggsql_stat__\"", stat_rename_exprs.join(", "), transformed_query ) diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 9e4ef350..09c84c14 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -905,7 +905,7 @@ pub struct PreparedData { /// # Arguments /// * `query` - The full ggsql query string /// * `reader` - A Reader implementation for executing SQL -pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result { +pub fn prepare_data_with_reader(query: &str, reader: &dyn Reader) -> Result { let execute_query = |sql: &str| reader.execute_sql(sql); let dialect = reader.dialect(); diff --git a/src/execute/schema.rs b/src/execute/schema.rs index ad80aa7f..2c514a78 100644 --- a/src/execute/schema.rs +++ b/src/execute/schema.rs @@ -30,7 +30,7 @@ pub fn build_minmax_query(source_query: &str, column_names: &[&str]) -> String { .collect(); format!( - "WITH __ggsql_source__ AS ({}) SELECT {} FROM __ggsql_source__ UNION ALL SELECT {} FROM __ggsql_source__", + "WITH \"__ggsql_source__\" AS ({}) 
SELECT {} FROM \"__ggsql_source__\" UNION ALL SELECT {} FROM \"__ggsql_source__\"", source_query, min_exprs.join(", "), max_exprs.join(", ") diff --git a/src/naming.rs b/src/naming.rs index 882f40dc..b36a05ec 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -224,6 +224,21 @@ pub fn aesthetic_column(aesthetic: &str) -> String { format!("{}{}{}", AES_PREFIX, aesthetic, GGSQL_SUFFIX) } +// ============================================================================ +// SQL Quoting +// ============================================================================ + +/// Double-quote a SQL identifier for case-preserving databases (e.g. Snowflake). +/// +/// # Example +/// ``` +/// use ggsql::naming; +/// assert_eq!(naming::quote_ident("__ggsql_aes_x__"), "\"__ggsql_aes_x__\""); +/// ``` +pub fn quote_ident(name: &str) -> String { + format!("\"{}\"", name) +} + // ============================================================================ // Detection Functions // ============================================================================ diff --git a/src/plot/layer/geom/bar.rs b/src/plot/layer/geom/bar.rs index cddda1d4..4915844b 100644 --- a/src/plot/layer/geom/bar.rs +++ b/src/plot/layer/geom/bar.rs @@ -178,19 +178,23 @@ fn stat_bar_count( if let Some(weight_col) = weight_value.column_name() { if schema_columns.contains(weight_col) { // weight column exists - use SUM (but still call it "count") - format!("SUM({}) AS {}", weight_col, stat_count) + format!( + "SUM({}) AS \"{}\"", + naming::quote_ident(weight_col), + stat_count + ) } else { // weight mapped but column doesn't exist - fall back to COUNT // (this shouldn't happen with upfront validation, but handle gracefully) - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", stat_count) } } else { // Shouldn't happen (not literal, not column), fall back to COUNT - format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", stat_count) } } else { // weight not mapped - use COUNT - 
format!("COUNT(*) AS {}", stat_count) + format!("COUNT(*) AS \"{}\"", stat_count) }; // Build the query based on whether x is mapped or not @@ -200,13 +204,13 @@ fn stat_bar_count( let (grouped_select, final_select) = if group_by.is_empty() { ( format!( - "'{dummy}' AS {x}, {agg}", + "'{dummy}' AS \"{x}\", {agg}", dummy = stat_dummy_value, x = stat_x, agg = agg_expr ), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{prop}\"", count = stat_count, prop = stat_proportion ), @@ -215,14 +219,14 @@ fn stat_bar_count( let grp_cols = group_by.join(", "); ( format!( - "{g}, '{dummy}' AS {x}, {agg}", + "{g}, '{dummy}' AS \"{x}\", {agg}", g = grp_cols, dummy = stat_dummy_value, x = stat_x, agg = agg_expr ), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{prop}\"", count = stat_count, grp = grp_cols, prop = stat_proportion @@ -233,7 +237,7 @@ fn stat_bar_count( let query_str = if group_by.is_empty() { // No grouping at all - single aggregate format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\") SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, final = final_select @@ -242,7 +246,7 @@ fn stat_bar_count( // Group by partition/facet variables only let group_cols = group_by.join(", "); format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, @@ -264,7 +268,7 @@ fn stat_bar_count( ) } else { // x is 
mapped - use existing logic with two-stage query - let x_col = x_col.unwrap(); + let x_col = naming::quote_ident(&x_col.unwrap()); // Build grouped columns (group_by includes partition_by + facet variables + x) let group_cols = if group_by.is_empty() { @@ -280,7 +284,7 @@ fn stat_bar_count( ( format!("{x}, {agg}", x = x_col, agg = agg_expr), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{prop}\"", count = stat_count, prop = stat_proportion ), @@ -290,7 +294,7 @@ fn stat_bar_count( ( format!("{g}, {x}, {agg}", g = grp_cols, x = x_col, agg = agg_expr), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {prop}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{prop}\"", count = stat_count, grp = grp_cols, prop = stat_proportion @@ -299,7 +303,7 @@ fn stat_bar_count( }; let query_str = format!( - "WITH __stat_src__ AS ({query}), __grouped__ AS (SELECT {grouped} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __grouped__", + "WITH \"__stat_src__\" AS ({query}), \"__grouped__\" AS (SELECT {grouped} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__grouped__\"", query = query, grouped = grouped_select, group = group_cols, diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index 89dcc921..54633d7c 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -169,12 +169,16 @@ fn boxplot_sql_compute_summary( coef: &f64, dialect: &dyn SqlDialect, ) -> String { - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]); let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]); let q1 = dialect.sql_percentile(value, 0.25, from, groups); let median = 
dialect.sql_percentile(value, 0.50, from, groups); let q3 = dialect.sql_percentile(value, 0.75, from, groups); + let qt = "\"__ggsql_qt__\""; + let fn_alias = "\"__ggsql_fn__\""; + let quoted_value = naming::quote_ident(value); format!( "SELECT *, @@ -188,14 +192,14 @@ fn boxplot_sql_compute_summary( {q1} AS q1, {median} AS median, {q3} AS q3 - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS {qt} WHERE {value} IS NOT NULL GROUP BY {groups} - ) AS __ggsql_fn__", + ) AS {fn_alias}", lower_expr = lower_expr, upper_expr = upper_expr, groups = groups_str, - value = value, + value = quoted_value, from = from, q1 = q1, median = median, @@ -207,10 +211,12 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St let mut join_pairs = Vec::new(); let mut keep_columns = Vec::new(); for column in groups { - join_pairs.push(format!("raw.{} = summary.{}", column, column)); - keep_columns.push(format!("raw.{}", column)); + let quoted = naming::quote_ident(column); + join_pairs.push(format!("raw.{} = summary.{}", quoted, quoted)); + keep_columns.push(format!("raw.{}", quoted)); } + let quoted_value = naming::quote_ident(value); // We're joining outliers with the summary to use the lower/upper whisker // values as a filter format!( @@ -221,7 +227,7 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St FROM ({from}) raw JOIN summary ON {pairs} WHERE raw.{value} NOT BETWEEN summary.lower AND summary.upper", - value = value, + value = quoted_value, groups = keep_columns.join(", "), pairs = join_pairs.join(" AND "), from = from @@ -239,19 +245,20 @@ fn boxplot_sql_append_outliers( let value2_name = naming::stat_column("value2"); let type_name = naming::stat_column("type"); - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); // Helper to build visual-element rows from summary table // Each row type maps to one visual 
element with y and yend where needed let build_summary_select = |table: &str| { format!( - "SELECT {groups}, 'lower_whisker' AS {type_name}, q1 AS {value_name}, lower AS {value2_name} FROM {table} + "SELECT {groups}, 'lower_whisker' AS \"{type_name}\", q1 AS \"{value_name}\", lower AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'upper_whisker' AS {type_name}, q3 AS {value_name}, upper AS {value2_name} FROM {table} + SELECT {groups}, 'upper_whisker' AS \"{type_name}\", q3 AS \"{value_name}\", upper AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'box' AS {type_name}, q1 AS {value_name}, q3 AS {value2_name} FROM {table} + SELECT {groups}, 'box' AS \"{type_name}\", q1 AS \"{value_name}\", q3 AS \"{value2_name}\" FROM {table} UNION ALL - SELECT {groups}, 'median' AS {type_name}, median AS {value_name}, NULL AS {value2_name} FROM {table}", + SELECT {groups}, 'median' AS \"{type_name}\", median AS \"{value_name}\", NULL AS \"{value2_name}\" FROM {table}", groups = groups_str, type_name = type_name, value_name = value_name, @@ -282,7 +289,7 @@ fn boxplot_sql_append_outliers( ) {summary_select} UNION ALL - SELECT {groups}, type AS {type_name}, value AS {value_name}, NULL AS {value2_name} + SELECT {groups}, type AS \"{type_name}\", value AS \"{value_name}\", NULL AS \"{value2_name}\" FROM outliers ", summary = from, @@ -306,14 +313,14 @@ mod tests { fn test_sql_compute_summary_basic() { let groups = vec!["category".to_string()]; let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5, &AnsiDialect); - assert!(result.contains("NTILE(4) OVER (ORDER BY value)")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"value\")")); assert!(result.contains("AS q1")); assert!(result.contains("AS median")); assert!(result.contains("AS q3")); - assert!(result.contains("MIN(value) AS min")); - assert!(result.contains("MAX(value) AS max")); - assert!(result.contains("WHERE value IS NOT NULL")); - assert!(result.contains("GROUP BY 
category")); + assert!(result.contains("MIN(\"value\") AS min")); + assert!(result.contains("MAX(\"value\") AS max")); + assert!(result.contains("WHERE \"value\" IS NOT NULL")); + assert!(result.contains("GROUP BY \"category\"")); assert!(result.contains("CASE WHEN (q1 - 1.5")); assert!(result.contains("CASE WHEN (q3 + 1.5")); } @@ -322,8 +329,8 @@ mod tests { fn test_sql_compute_summary_multiple_groups() { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5, &AnsiDialect); - assert!(result.contains("GROUP BY cat, region")); - assert!(result.contains("NTILE(4) OVER (ORDER BY val)")); + assert!(result.contains("GROUP BY \"cat\", \"region\"")); + assert!(result.contains("NTILE(4) OVER (ORDER BY \"val\")")); } #[test] @@ -344,8 +351,8 @@ mod tests { let groups = vec!["cat".to_string(), "region".to_string()]; let result = boxplot_sql_filter_outliers(&groups, "value", "raw_data"); assert!(result.contains("JOIN summary ON")); - assert!(result.contains("raw.cat = summary.cat")); - assert!(result.contains("raw.region = summary.region")); + assert!(result.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(result.contains("raw.\"region\" = summary.\"region\"")); assert!(result.contains("NOT BETWEEN summary.lower AND summary.upper")); assert!(result.contains("'outlier' AS type")); } @@ -373,16 +380,16 @@ mod tests { (CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper FROM ( SELECT - category, - MIN(price) AS min, - MAX(price) AS max, + "category", + MIN("price") AS min, + MAX("price") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM sales) AS __ggsql_qt__ - WHERE price IS NOT NULL - GROUP BY category - ) AS __ggsql_fn__"# + FROM (SELECT * FROM sales) AS "__ggsql_qt__" + WHERE "price" IS NOT NULL + GROUP BY "category" + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); @@ -409,16 +416,16 @@ mod tests { (CASE WHEN (q3 + 1.5 * 
(q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper FROM ( SELECT - region, product, - MIN(revenue) AS min, - MAX(revenue) AS max, + "region", "product", + MIN("revenue") AS min, + MAX("revenue") AS max, {q1} AS q1, {median} AS median, {q3} AS q3 - FROM (SELECT * FROM data) AS __ggsql_qt__ - WHERE revenue IS NOT NULL - GROUP BY region, product - ) AS __ggsql_fn__"# + FROM (SELECT * FROM data) AS "__ggsql_qt__" + WHERE "revenue" IS NOT NULL + GROUP BY "region", "product" + ) AS "__ggsql_fn__""# ); assert_eq!(result, expected); @@ -445,9 +452,9 @@ mod tests { assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] @@ -469,9 +476,9 @@ mod tests { assert!(result.contains("'median'")); // Check column names - assert!(result.contains(&format!("AS {}", naming::stat_column("value")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("value2")))); - assert!(result.contains(&format!("AS {}", naming::stat_column("type")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("value2")))); + assert!(result.contains(&format!("AS \"{}\"", naming::stat_column("type")))); } #[test] @@ -481,8 +488,8 @@ mod tests { let raw = "(SELECT * FROM raw_data)"; let result = boxplot_sql_append_outliers(summary, &groups, "val", raw, &true); - // Verify all groups are present - assert!(result.contains("cat, region, year")); + // Verify all groups are present (quoted) + 
assert!(result.contains("\"cat\", \"region\", \"year\"")); // Check structure assert!(result.contains("WITH")); @@ -491,9 +498,9 @@ mod tests { // Verify outlier join conditions for all groups let outlier_section = result.split("outliers AS").nth(1).unwrap(); - assert!(outlier_section.contains("raw.cat = summary.cat")); - assert!(outlier_section.contains("raw.region = summary.region")); - assert!(outlier_section.contains("raw.year = summary.year")); + assert!(outlier_section.contains("raw.\"cat\" = summary.\"cat\"")); + assert!(outlier_section.contains("raw.\"region\" = summary.\"region\"")); + assert!(outlier_section.contains("raw.\"year\" = summary.\"year\"")); } // ==================== GeomTrait Implementation Tests ==================== diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index e67e0cbe..2653e448 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -229,13 +229,15 @@ fn density_sql_bandwidth( let (groups_select, group_by) = if groups.is_empty() { (String::new(), String::new()) } else { - let groups_str = groups.join(", "); + let quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); ( format!("\n {},", groups_str), format!("\n GROUP BY {}", groups_str), ) }; + let quoted_value = naming::quote_ident(value); format!( "WITH RECURSIVE bandwidth AS ( @@ -243,12 +245,12 @@ fn density_sql_bandwidth( {bw_expr} AS bw,{groups_select} MIN({value}) AS x_min, MAX({value}) AS x_max - FROM ({from}) AS __ggsql_qt__ + FROM ({from}) AS \"__ggsql_qt__\" WHERE {value} IS NOT NULL{group_by} )", bw_expr = bw_expr, groups_select = groups_select, - value = value, + value = quoted_value, from = from, group_by = group_by ) @@ -264,7 +266,8 @@ fn silverman_rule( // The query computes Silverman's rule of thumb (R's `stats::bw.nrd0()`). 
// We absorb the adjustment in the 0.9 multiplier of the rule let adjust = 0.9 * adjust; - let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = value_column); + let v = naming::quote_ident(value_column); + let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = v); let q75 = dialect.sql_percentile(value_column, 0.75, from, groups); let q25 = dialect.sql_percentile(value_column, 0.25, from, groups); let iqr = format!("({q75} - {q25}) / 1.34"); @@ -351,34 +354,36 @@ fn build_data_cte( ) -> String { // Include weight column if provided, otherwise default to 1.0 let weight_col = if let Some(w) = weight { - format!(", {} AS weight", w) + format!(", {} AS weight", naming::quote_ident(w)) } else { ", 1.0 AS weight".to_string() }; let smooth_col = if let Some(s) = smooth { - format!(", {}", s) + format!(", {}", naming::quote_ident(s)) } else { "".to_string() }; + let quoted_value = naming::quote_ident(value); // Only filter out nulls in value column, keep NULLs in group columns - let mut filter_valid = format!("{} IS NOT NULL", value); + let mut filter_valid = format!("{} IS NOT NULL", quoted_value); if let Some(s) = smooth { filter_valid = format!( - "{filter} AND {smth} IS NOT NULL", + "{filter} AND {} IS NOT NULL", + naming::quote_ident(s), filter = filter_valid, - smth = s ); } + let quoted_groups: Vec<String> = group_by.iter().map(|g| naming::quote_ident(g)).collect(); format!( "data AS ( SELECT {groups}{value} AS val{weight_col}{smooth_col} FROM ({from}) WHERE {filter_valid} )", - groups = with_trailing_comma(&group_by.join(", ")), - value = value, + groups = with_trailing_comma(&quoted_groups.join(", ")), + value = quoted_value, weight_col = weight_col, smooth_col = smooth_col, from = from, @@ -420,12 +425,13 @@ fn build_grid_cte( "grid AS ( SELECT {x_formula} AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN \"__ggsql_seq__\" AS seq )", x_formula = x_formula ) } else { - let 
quoted_groups: Vec = groups.iter().map(|g| naming::quote_ident(g)).collect(); + let groups_str = quoted_groups.join(", "); // When tails is specified, create full_grid; otherwise create grid directly let cte_name = if tails.is_some() { "full_grid" } else { "grid" }; format!( @@ -434,7 +440,7 @@ fn build_grid_cte( {groups}, {x_formula} AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN \"__ggsql_seq__\" AS seq CROSS JOIN (SELECT DISTINCT {groups} FROM bandwidth) AS groups )", cte_name = cte_name, @@ -449,14 +455,14 @@ fn build_grid_cte( let bandwidth_join_conds: Vec = groups .iter() .map(|g| { - format!( - "full_grid.{col} IS NOT DISTINCT FROM bandwidth.{col}", - col = g - ) + let q = naming::quote_ident(g); + format!("full_grid.{q} IS NOT DISTINCT FROM bandwidth.{q}") }) .collect(); - let grid_groups_select: Vec = - groups.iter().map(|g| format!("full_grid.{}", g)).collect(); + let grid_groups_select: Vec = groups + .iter() + .map(|g| format!("full_grid.{}", naming::quote_ident(g))) + .collect(); format!( "{seq_cte}, @@ -513,7 +519,10 @@ fn compute_density( } else { group_by .iter() - .map(|g| format!("data.{col} IS NOT DISTINCT FROM bandwidth.{col}", col = g)) + .map(|g| { + let q = naming::quote_ident(g); + format!("data.{q} IS NOT DISTINCT FROM bandwidth.{q}") + }) .collect::>() .join(" AND ") }; @@ -524,7 +533,10 @@ fn compute_density( } else { let grid_data_conds: Vec = group_by .iter() - .map(|g| format!("grid.{col} IS NOT DISTINCT FROM data.{col}", col = g)) + .map(|g| { + let q = naming::quote_ident(g); + format!("grid.{q} IS NOT DISTINCT FROM data.{q}") + }) .collect(); format!("WHERE {}", grid_data_conds.join(" AND ")) }; @@ -538,7 +550,10 @@ fn compute_density( ); // Build group-related SQL fragments - let grid_groups: Vec = group_by.iter().map(|g| format!("grid.{}", g)).collect(); + let grid_groups: Vec = group_by + .iter() + .map(|g| format!("grid.{}", naming::quote_ident(g))) + .collect(); let aggregation = format!( "GROUP 
BY grid.x{grid_group_by} ORDER BY grid.x{grid_group_by}", @@ -548,35 +563,40 @@ fn compute_density( let groups = if group_by.is_empty() { String::new() } else { - format!("{},", group_by.join(", ")) + let quoted: Vec = group_by.iter().map(|g| naming::quote_ident(g)).collect(); + format!("{},", quoted.join(", ")) }; + let x_column = naming::stat_column(value_aesthetic); + let intensity_column = naming::stat_column("intensity"); + let density_column = naming::stat_column("density"); + // Generate the density computation query format!( "{bandwidth_cte}, {data_cte}, {grid_cte} SELECT - {x_column}, + \"{x_column}\", {groups} - {intensity_column}, - {intensity_column} / __norm AS {density_column} + \"{intensity_column}\", + \"{intensity_column}\" / \"__norm\" AS \"{density_column}\" FROM ( SELECT - grid.x AS {x_column}, + grid.x AS \"{x_column}\", {grid_groups} - {kernel} AS {intensity_column}, - SUM(data.weight) AS __norm + {kernel} AS \"{intensity_column}\", + SUM(data.weight) AS \"__norm\" {join_logic} {aggregation} )", bandwidth_cte = bandwidth_cte, data_cte = data_cte, grid_cte = grid_cte, - x_column = naming::stat_column(value_aesthetic), + x_column = x_column, groups = groups, - intensity_column = naming::stat_column("intensity"), - density_column = naming::stat_column("density"), + intensity_column = intensity_column, + density_column = density_column, aggregation = aggregation, grid_groups = with_trailing_comma(&grid_groups.join(", ")) ) @@ -606,21 +626,21 @@ mod tests { let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE + let expected = r#"WITH RECURSIVE bandwidth AS ( SELECT 0.5 AS bw, - MIN(x) AS x_min, - MAX(x) AS x_max - FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) AS __ggsql_qt__ - WHERE x IS NOT NULL + MIN("x") AS x_min, + MAX("x") AS x_max + FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) AS 
"__ggsql_qt__" + WHERE "x" IS NOT NULL ), data AS ( - SELECT x AS val, 1.0 AS weight + SELECT "x" AS val, 1.0 AS weight FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), global_range AS ( SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion FROM bandwidth @@ -628,23 +648,23 @@ mod tests { grid AS ( SELECT (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq + CROSS JOIN "__ggsql_seq__" AS seq ) SELECT - __ggsql_stat_x, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_x", + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + grid.x AS "__ggsql_stat_x", + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data INNER JOIN bandwidth ON true CROSS JOIN grid GROUP BY grid.x ORDER BY grid.x - )"; + )"#; // Normalize whitespace for comparison let normalize = |s: &str| 
s.split_whitespace().collect::>().join(" "); @@ -682,53 +702,53 @@ mod tests { let kernel = choose_kde_kernel(¶meters, None).expect("kernel should be valid"); let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte); - let expected = "WITH RECURSIVE + let expected = r#"WITH RECURSIVE bandwidth AS ( SELECT 0.5 AS bw, - region, category, - MIN(x) AS x_min, - MAX(x) AS x_max - FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) AS __ggsql_qt__ - WHERE x IS NOT NULL - GROUP BY region, category + "region", "category", + MIN("x") AS x_min, + MAX("x") AS x_max + FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) AS "__ggsql_qt__" + WHERE "x" IS NOT NULL + GROUP BY "region", "category" ), data AS ( - SELECT region, category, x AS val, 1.0 AS weight + SELECT "region", "category", "x" AS val, 1.0 AS weight FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) - WHERE x IS NOT NULL + WHERE "x" IS NOT NULL ), - __ggsql_base__(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < 7),__ggsql_seq__(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c WHERE a.n * 64 + b.n * 8 + c.n < 512), + "__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512), global_range AS ( SELECT MIN(x_min) AS min, MAX(x_max) AS max, 3 * MAX(bw) AS expansion FROM bandwidth ), grid AS ( SELECT - region, category, + "region", "category", (global.min - global.expansion) + (seq.n * ((global.max - global.min) + 2 * global.expansion) / 511) AS x FROM global_range AS global - CROSS JOIN __ggsql_seq__ AS seq - CROSS JOIN (SELECT DISTINCT region, 
category FROM bandwidth) AS groups + CROSS JOIN "__ggsql_seq__" AS seq + CROSS JOIN (SELECT DISTINCT "region", "category" FROM bandwidth) AS groups ) SELECT - __ggsql_stat_x, - region, category, - __ggsql_stat_intensity, - __ggsql_stat_intensity / __norm AS __ggsql_stat_density + "__ggsql_stat_x", + "region", "category", + "__ggsql_stat_intensity", + "__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density" FROM ( SELECT - grid.x AS __ggsql_stat_x, - grid.region, grid.category, - SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS __ggsql_stat_intensity, - SUM(data.weight) AS __norm + grid.x AS "__ggsql_stat_x", + grid."region", grid."category", + SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity", + SUM(data.weight) AS "__norm" FROM data - INNER JOIN bandwidth ON data.region IS NOT DISTINCT FROM bandwidth.region AND data.category IS NOT DISTINCT FROM bandwidth.category + INNER JOIN bandwidth ON data."region" IS NOT DISTINCT FROM bandwidth."region" AND data."category" IS NOT DISTINCT FROM bandwidth."category" CROSS JOIN grid - WHERE grid.region IS NOT DISTINCT FROM data.region AND grid.category IS NOT DISTINCT FROM data.category - GROUP BY grid.x, grid.region, grid.category - ORDER BY grid.x, grid.region, grid.category - )"; + WHERE grid."region" IS NOT DISTINCT FROM data."region" AND grid."category" IS NOT DISTINCT FROM data."category" + GROUP BY grid.x, grid."region", grid."category" + ORDER BY grid.x, grid."region", grid."category" + )"#; // Normalize whitespace for comparison let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); @@ -822,7 +842,7 @@ mod tests { // Verify SQL uses NTILE-based percentile subqueries with grouping assert!(bw_cte.contains("NTILE(4)")); - assert!(bw_cte.contains("GROUP BY region")); + 
assert!(bw_cte.contains("GROUP BY \"region\"")); let expected_rule = silverman_rule(1.0, "x", query, &groups, &AnsiDialect); assert!(normalize(&bw_cte).contains(&normalize(&expected_rule))); diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index fef6fb2e..9176956e 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; -use super::types::{get_column_name, CLOSED_VALUES, POSITION_VALUES}; +use super::types::{get_quoted_column_name, CLOSED_VALUES, POSITION_VALUES}; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, StatResult, @@ -125,7 +125,7 @@ fn stat_histogram( dialect: &dyn SqlDialect, ) -> Result { // Get x column name from aesthetics - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Histogram requires 'x' aesthetic mapping".to_string()) })?; @@ -149,7 +149,7 @@ fn stat_histogram( // Query min/max to compute bin width let stats_query = format!( - "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS __ggsql_stats__", + "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS \"__ggsql_stats__\"", x = x_col, query = query ); @@ -213,7 +213,7 @@ fn stat_histogram( )); } if let Some(weight_col) = weight_value.column_name() { - format!("SUM({})", weight_col) + format!("SUM({})", naming::quote_ident(weight_col)) } else { "COUNT(*)".to_string() } @@ -232,11 +232,11 @@ fn stat_histogram( let (binned_select, final_select) = if group_by.is_empty() { ( format!( - "{} AS {}, {} AS {}, {} AS {}", + "{} AS \"{}\", {} AS \"{}\", {} AS \"{}\"", bin_expr, stat_bin, bin_end_expr, stat_bin_end, agg_expr, stat_count ), format!( - "*, {count} * 1.0 / SUM({count}) OVER () AS {density}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER () AS \"{density}\"", count = stat_count, density 
= stat_density ), @@ -245,11 +245,11 @@ fn stat_histogram( let grp_cols = group_by.join(", "); ( format!( - "{}, {} AS {}, {} AS {}, {} AS {}", + "{}, {} AS \"{}\", {} AS \"{}\", {} AS \"{}\"", grp_cols, bin_expr, stat_bin, bin_end_expr, stat_bin_end, agg_expr, stat_count ), format!( - "*, {count} * 1.0 / SUM({count}) OVER (PARTITION BY {grp}) AS {density}", + "*, \"{count}\" * 1.0 / SUM(\"{count}\") OVER (PARTITION BY {grp}) AS \"{density}\"", count = stat_count, grp = grp_cols, density = stat_density @@ -258,7 +258,7 @@ fn stat_histogram( }; let transformed_query = format!( - "WITH __stat_src__ AS ({query}), __binned__ AS (SELECT {binned} FROM __stat_src__ GROUP BY {group}) SELECT {final} FROM __binned__", + "WITH \"__stat_src__\" AS ({query}), \"__binned__\" AS (SELECT {binned} FROM \"__stat_src__\" GROUP BY {group}) SELECT {final} FROM \"__binned__\"", query = query, binned = binned_select, group = group_cols, diff --git a/src/plot/layer/geom/rect.rs b/src/plot/layer/geom/rect.rs index fdd2584a..6aeeb928 100644 --- a/src/plot/layer/geom/rect.rs +++ b/src/plot/layer/geom/rect.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; -use super::types::get_column_name; use super::types::POSITION_VALUES; +use super::types::{get_column_name, get_quoted_column_name}; use super::{DefaultAesthetics, GeomTrait, GeomType, ParamConstraint, StatResult}; use crate::naming; use crate::plot::types::{DefaultAestheticValue, ParameterValue}; @@ -130,15 +130,17 @@ fn process_direction( _ => unreachable!("axis must be 'x' or 'y'"), }; - // Get column names from MAPPING, with SETTING fallback for size - let center = get_column_name(aesthetics, center_aes); - let min = get_column_name(aesthetics, min_aes); - let max = get_column_name(aesthetics, max_aes); - let size = get_column_name(aesthetics, size_aes) + // Get unquoted center name for schema lookup + let center_unquoted = get_column_name(aesthetics, center_aes); + let center = center_unquoted.as_deref().map(naming::quote_ident); + let 
min = get_quoted_column_name(aesthetics, min_aes); + let max = get_quoted_column_name(aesthetics, max_aes); + // SETTING fallback for size is a literal value, no quoting needed. + let size = get_quoted_column_name(aesthetics, size_aes) .or_else(|| parameters.get(size_aes).map(|v| v.to_string())); // Detect if discrete by checking schema - let is_discrete = center + let is_discrete = center_unquoted .as_ref() .and_then(|col| schema.iter().find(|c| &c.name == col)) .map(|c| c.is_discrete) @@ -172,8 +174,8 @@ fn process_direction( // Build SELECT parts using the stat columns let select_parts = vec![ - format!("{} AS {}", expr_1, naming::stat_column(&stat_cols[0])), - format!("{} AS {}", expr_2, naming::stat_column(&stat_cols[1])), + format!("{} AS \"{}\"", expr_1, naming::stat_column(&stat_cols[0])), + format!("{} AS \"{}\"", expr_2, naming::stat_column(&stat_cols[1])), ]; Ok((select_parts, stat_cols)) @@ -208,7 +210,7 @@ fn stat_rect( let mut select_parts: Vec = schema .iter() .filter(|col| !consumed_columns.contains(&col.name)) - .map(|col| col.name.clone()) + .map(|col| naming::quote_ident(&col.name)) .collect(); // Add X direction SELECT parts and collect stat columns @@ -223,7 +225,7 @@ fn stat_rect( // Build transformed query let transformed_query = format!( - "SELECT {} FROM ({}) AS __ggsql_rect_stat__", + "SELECT {} FROM ({}) AS \"__ggsql_rect_stat__\"", select_list, query ); @@ -446,44 +448,44 @@ mod tests { ( "xmin + xmax", vec!["pos1min", "pos1max"], - "__ggsql_aes_pos1min__", - "__ggsql_aes_pos1max__", + "\"__ggsql_aes_pos1min__\"", + "\"__ggsql_aes_pos1max__\"", ), ( "x + width", vec!["pos1", "width"], - "(__ggsql_aes_pos1__ - __ggsql_aes_width__ / 2.0)", - "(__ggsql_aes_pos1__ + __ggsql_aes_width__ / 2.0)", + "(\"__ggsql_aes_pos1__\" - \"__ggsql_aes_width__\" / 2.0)", + "(\"__ggsql_aes_pos1__\" + \"__ggsql_aes_width__\" / 2.0)", ), ( "x only (default width 1.0)", vec!["pos1"], - "(__ggsql_aes_pos1__ - 0.5)", - "(__ggsql_aes_pos1__ + 0.5)", + 
"(\"__ggsql_aes_pos1__\" - 0.5)", + "(\"__ggsql_aes_pos1__\" + 0.5)", ), ( "x + xmin", vec!["pos1", "pos1min"], - "__ggsql_aes_pos1min__", - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1min__)", + "\"__ggsql_aes_pos1min__\"", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1min__\")", ), ( "x + xmax", vec!["pos1", "pos1max"], - "(2 * __ggsql_aes_pos1__ - __ggsql_aes_pos1max__)", - "__ggsql_aes_pos1max__", + "(2 * \"__ggsql_aes_pos1__\" - \"__ggsql_aes_pos1max__\")", + "\"__ggsql_aes_pos1max__\"", ), ( "xmin + width", vec!["pos1min", "width"], - "__ggsql_aes_pos1min__", - "(__ggsql_aes_pos1min__ + __ggsql_aes_width__)", + "\"__ggsql_aes_pos1min__\"", + "(\"__ggsql_aes_pos1min__\" + \"__ggsql_aes_width__\")", ), ( "xmax + width", vec!["pos1max", "width"], - "(__ggsql_aes_pos1max__ - __ggsql_aes_width__)", - "__ggsql_aes_pos1max__", + "(\"__ggsql_aes_pos1max__\" - \"__ggsql_aes_width__\")", + "\"__ggsql_aes_pos1max__\"", ), ]; @@ -522,7 +524,7 @@ mod tests { let stat_pos1min = naming::stat_column("pos1min"); let stat_pos1max = naming::stat_column("pos1max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos1min)), + query.contains(&format!("{} AS \"{}\"", expected_min, stat_pos1min)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_min, @@ -530,7 +532,7 @@ mod tests { query ); assert!( - query.contains(&format!("{} AS {}", expected_max, stat_pos1max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos1max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -562,38 +564,38 @@ mod tests { ( "ymin + ymax", vec!["pos2min", "pos2max"], - "__ggsql_aes_pos2min__", - "__ggsql_aes_pos2max__", + "\"__ggsql_aes_pos2min__\"", + "\"__ggsql_aes_pos2max__\"", ), ( "y + height", vec!["pos2", "height"], - "(__ggsql_aes_pos2__ - __ggsql_aes_height__ / 2.0)", - "(__ggsql_aes_pos2__ + __ggsql_aes_height__ / 2.0)", + "(\"__ggsql_aes_pos2__\" - \"__ggsql_aes_height__\" / 2.0)", + "(\"__ggsql_aes_pos2__\" + 
\"__ggsql_aes_height__\" / 2.0)", ), ( "y + ymin", vec!["pos2", "pos2min"], - "__ggsql_aes_pos2min__", - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2min__)", + "\"__ggsql_aes_pos2min__\"", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2min__\")", ), ( "y + ymax", vec!["pos2", "pos2max"], - "(2 * __ggsql_aes_pos2__ - __ggsql_aes_pos2max__)", - "__ggsql_aes_pos2max__", + "(2 * \"__ggsql_aes_pos2__\" - \"__ggsql_aes_pos2max__\")", + "\"__ggsql_aes_pos2max__\"", ), ( "ymin + height", vec!["pos2min", "height"], - "__ggsql_aes_pos2min__", - "(__ggsql_aes_pos2min__ + __ggsql_aes_height__)", + "\"__ggsql_aes_pos2min__\"", + "(\"__ggsql_aes_pos2min__\" + \"__ggsql_aes_height__\")", ), ( "ymax + height", vec!["pos2max", "height"], - "(__ggsql_aes_pos2max__ - __ggsql_aes_height__)", - "__ggsql_aes_pos2max__", + "(\"__ggsql_aes_pos2max__\" - \"__ggsql_aes_height__\")", + "\"__ggsql_aes_pos2max__\"", ), ]; @@ -632,7 +634,7 @@ mod tests { let stat_pos2min = naming::stat_column("pos2min"); let stat_pos2max = naming::stat_column("pos2max"); assert!( - query.contains(&format!("{} AS {}", expected_min, stat_pos2min)), + query.contains(&format!("{} AS \"{}\"", expected_min, stat_pos2min)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_min, @@ -640,7 +642,7 @@ mod tests { query ); assert!( - query.contains(&format!("{} AS {}", expected_max, stat_pos2max)), + query.contains(&format!("{} AS \"{}\"", expected_max, stat_pos2max)), "{}: Expected '{} AS {}' in query, got: {}", name, expected_max, @@ -687,8 +689,8 @@ mod tests { .. 
}) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"pos1".to_string())); assert!(stat_columns.contains(&"width".to_string())); assert!(stat_columns.contains(&"pos2min".to_string())); @@ -718,8 +720,8 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); assert!(stat_columns.contains(&"pos2".to_string())); @@ -749,10 +751,10 @@ mod tests { .. }) = result { - assert!(query.contains("__ggsql_aes_pos1__ AS __ggsql_stat_pos1")); - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_pos2__ AS __ggsql_stat_pos2")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_pos1__\" AS \"__ggsql_stat_pos1")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_pos2__\" AS \"__ggsql_stat_pos2")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); assert_eq!(stat_columns.len(), 4); } } @@ -782,8 +784,8 @@ mod tests { stat_columns, .. 
} => { - assert!(query.contains("(__ggsql_aes_pos1__ - 0.5)")); - assert!(query.contains("(__ggsql_aes_pos1__ + 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" - 0.5)")); + assert!(query.contains("(\"__ggsql_aes_pos1__\" + 0.5)")); assert!(stat_columns.contains(&"pos1min".to_string())); assert!(stat_columns.contains(&"pos1max".to_string())); } @@ -852,7 +854,7 @@ mod tests { stat_columns, .. } => { - assert!(query.contains("1.0 AS __ggsql_stat_width")); + assert!(query.contains("1.0 AS \"__ggsql_stat_width")); assert!(stat_columns.contains(&"width".to_string())); } _ => panic!("Expected Transformed"), @@ -879,12 +881,12 @@ mod tests { assert!(result.is_ok()); if let Ok(StatResult::Transformed { query, .. }) = result { - // Should include fill column (non-consumed aesthetic from schema) - assert!(query.contains("__ggsql_aes_fill__")); + // Should include fill column (non-consumed aesthetic from schema, quoted) + assert!(query.contains("\"__ggsql_aes_fill__\"")); // Should NOT include width/height as pass-through (they're consumed) // They should only appear as stat columns - assert!(query.contains("__ggsql_aes_width__ AS __ggsql_stat_width")); - assert!(query.contains("__ggsql_aes_height__ AS __ggsql_stat_height")); + assert!(query.contains("\"__ggsql_aes_width__\" AS \"__ggsql_stat_width")); + assert!(query.contains("\"__ggsql_aes_height__\" AS \"__ggsql_stat_height")); } } @@ -909,8 +911,8 @@ mod tests { if let Ok(StatResult::Transformed { query, .. 
}) = result { // Should use SETTING values as SQL literals - assert!(query.contains("0.7 AS __ggsql_stat_width")); - assert!(query.contains("0.9 AS __ggsql_stat_height")); + assert!(query.contains("0.7 AS \"__ggsql_stat_width")); + assert!(query.contains("0.9 AS \"__ggsql_stat_height")); } } } diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index fad14432..d81ef8ff 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -4,7 +4,7 @@ use super::types::POSITION_VALUES; use super::{ DefaultAesthetics, DefaultParamValue, GeomTrait, GeomType, ParamConstraint, ParamDefinition, }; -use crate::plot::geom::types::get_column_name; +use crate::plot::geom::types::get_quoted_column_name; use crate::plot::types::DefaultAestheticValue; use crate::plot::{ParameterValue, StatResult}; use crate::reader::SqlDialect; @@ -136,10 +136,10 @@ impl std::fmt::Display for Smooth { } fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; @@ -172,13 +172,13 @@ fn stat_ols(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Result Result { - let x_col = get_column_name(aesthetics, "pos1").ok_or_else(|| { + let x_col = get_quoted_column_name(aesthetics, "pos1").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos1' aesthetic".to_string()) })?; - let y_col = get_column_name(aesthetics, "pos2").ok_or_else(|| { + let y_col = get_quoted_column_name(aesthetics, "pos2").ok_or_else(|| { GgsqlError::ValidationError("Smooth requires 'pos2' aesthetic".to_string()) })?; 
@@ -245,13 +245,13 @@ fn stat_tls(query: &str, aesthetics: &Mappings, group_by: &[String]) -> Result Option }) } +/// Helper to extract a double-quoted column name for use in SQL expressions. +pub fn get_quoted_column_name(aesthetics: &Mappings, aesthetic: &str) -> Option { + get_column_name(aesthetics, aesthetic).map(|n| crate::naming::quote_ident(&n)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index 289ecc89..b491ca1f 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -8,6 +8,7 @@ use super::{ expand_numeric_range, resolve_common_steps, ScaleDataContext, ScaleTypeKind, ScaleTypeTrait, TransformKind, CLOSED_VALUES, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_BINNED, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -727,20 +728,21 @@ fn build_bin_condition( (if is_first { ">=" } else { ">" }, "<=") }; + let quoted = naming::quote_ident(column_name); if oob_squish && is_first && is_last { // Single bin with squish: capture everything "TRUE".to_string() } else if oob_squish && is_first { // First bin with squish: no lower bound, extends to -∞ - format!("{} {} {}", column_name, upper_op, upper_expr) + format!("{} {} {}", quoted, upper_op, upper_expr) } else if oob_squish && is_last { // Last bin with squish: no upper bound, extends to +∞ - format!("{} {} {}", column_name, lower_op, lower_expr) + format!("{} {} {}", quoted, lower_op, lower_expr) } else { // Normal bin with both bounds format!( "{} {} {} AND {} {} {}", - column_name, lower_op, lower_expr, column_name, upper_op, upper_expr + quoted, lower_op, lower_expr, quoted, upper_op, upper_expr ) } } @@ -855,10 +857,10 @@ mod tests { // Should produce CASE WHEN with bin centers 5, 15, 25 assert!(sql.contains("CASE")); - assert!(sql.contains("WHEN value >= 0 AND value < 10 THEN 5")); - 
assert!(sql.contains("WHEN value >= 10 AND value < 20 THEN 15")); + assert!(sql.contains("WHEN \"value\" >= 0 AND \"value\" < 10 THEN 5")); + assert!(sql.contains("WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15")); // Last bin should be inclusive on both ends - assert!(sql.contains("WHEN value >= 20 AND value <= 30 THEN 25")); + assert!(sql.contains("WHEN \"value\" >= 20 AND \"value\" <= 30 THEN 25")); assert!(sql.contains("ELSE NULL END")); } @@ -906,8 +908,8 @@ mod tests { .unwrap(); // closed="left": [lower, upper) except last which is [lower, upper] - assert!(sql.contains("col >= 0 AND col < 10")); - assert!(sql.contains("col >= 10 AND col <= 20")); // last bin inclusive + assert!(sql.contains("\"col\" >= 0 AND \"col\" < 10")); + assert!(sql.contains("\"col\" >= 10 AND \"col\" <= 20")); // last bin inclusive } #[test] @@ -932,8 +934,8 @@ mod tests { .unwrap(); // closed="right": first bin is [lower, upper], rest are (lower, upper] - assert!(sql.contains("col >= 0 AND col <= 10")); // first bin inclusive - assert!(sql.contains("col > 10 AND col <= 20")); + assert!(sql.contains("\"col\" >= 0 AND \"col\" <= 10")); // first bin inclusive + assert!(sql.contains("\"col\" > 10 AND \"col\" <= 20")); } #[test] @@ -1191,8 +1193,8 @@ mod tests { sql ); assert!( - sql.contains("value >= 0"), - "SQL should use raw column name. Got: {}", + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name. 
Got: {}", sql ); assert!( @@ -1227,7 +1229,10 @@ mod tests { !sql.contains("CAST("), "SQL should not contain CAST when column is numeric" ); - assert!(sql.contains("value >= 0"), "SQL should use raw column name"); + assert!( + sql.contains("\"value\" >= 0"), + "SQL should use quoted column name" + ); } #[test] @@ -1503,9 +1508,9 @@ mod tests { "left", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value < 10 THEN 5", // First bin extends to -∞ - "WHEN value >= 10 AND value < 20 THEN 15", // Middle bin - "WHEN value >= 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" < 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" >= 10 AND \"value\" < 20 THEN 15", // Middle bin + "WHEN \"value\" >= 20 THEN 25", // Last bin extends to +∞ ], ), // closed="right" with 3 bins (4 breaks) @@ -1513,9 +1518,9 @@ mod tests { "right", vec![0.0, 10.0, 20.0, 30.0], vec![ - "WHEN value <= 10 THEN 5", // First bin extends to -∞ - "WHEN value > 10 AND value <= 20 THEN 15", // Middle bin - "WHEN value > 20 THEN 25", // Last bin extends to +∞ + "WHEN \"value\" <= 10 THEN 5", // First bin extends to -∞ + "WHEN \"value\" > 10 AND \"value\" <= 20 THEN 15", // Middle bin + "WHEN \"value\" > 20 THEN 25", // Last bin extends to +∞ ], ), ]; @@ -1576,11 +1581,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("WHEN x < 50 THEN 25"), + sql.contains("WHEN \"x\" < 50 THEN 25"), "Two bins: first should extend to -∞" ); assert!( - sql.contains("WHEN x >= 50 THEN 75"), + sql.contains("WHEN \"x\" >= 50 THEN 75"), "Two bins: last should extend to +∞" ); } @@ -1625,11 +1630,11 @@ mod tests { .pre_stat_transform_sql("x", &DataType::Float64, &scale, &AnsiDialect) .unwrap(); assert!( - sql.contains("x >= 0 AND x < 10"), + sql.contains("\"x\" >= 0 AND \"x\" < 10"), "First bin should have lower bound with censor" ); assert!( - sql.contains("x >= 10 AND x <= 20"), + sql.contains("\"x\" >= 10 AND \"x\" <= 20"), "Last bin should have 
upper bound with censor" ); } @@ -1642,14 +1647,17 @@ mod tests { ( true, vec![ - "WHEN col < 10 THEN 5", - "WHEN col >= 10 AND col < 20 THEN 15", - "WHEN col >= 20 THEN 25", + "WHEN \"col\" < 10 THEN 5", + "WHEN \"col\" >= 10 AND \"col\" < 20 THEN 15", + "WHEN \"col\" >= 20 THEN 25", ], ), ( false, - vec!["col >= 0 AND col < 10", "col >= 10 AND col <= 20"], + vec![ + "\"col\" >= 0 AND \"col\" < 10", + "\"col\" >= 10 AND \"col\" <= 20", + ], ), ]; diff --git a/src/plot/scale/scale_type/continuous.rs b/src/plot/scale/scale_type/continuous.rs index 92f2c397..8665b0ad 100644 --- a/src/plot/scale/scale_type/continuous.rs +++ b/src/plot/scale/scale_type/continuous.rs @@ -5,6 +5,7 @@ use polars::prelude::DataType; use super::{ ScaleTypeKind, ScaleTypeTrait, TransformKind, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_CONTINUOUS, }; +use crate::naming; use crate::plot::types::{ ArrayConstraint, DefaultParamValue, NumberConstraint, ParamConstraint, ParamDefinition, }; @@ -214,14 +215,18 @@ impl ScaleTypeTrait for Continuous { .unwrap_or(super::default_oob(&scale.aesthetic)); match oob { - OOB_CENSOR => Some(format!( - "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", - column_name, min, column_name, max, column_name - )), + OOB_CENSOR => { + let quoted = naming::quote_ident(column_name); + Some(format!( + "(CASE WHEN {} >= {} AND {} <= {} THEN {} ELSE NULL END)", + quoted, min, quoted, max, quoted + )) + } OOB_SQUISH => { let min_s = min.to_string(); let max_s = max.to_string(); - let inner = dialect.sql_least(&[&max_s, column_name]); + let quoted = naming::quote_ident(column_name); + let inner = dialect.sql_least(&[&max_s, &quoted]); Some(dialect.sql_greatest(&[&min_s, &inner])) } _ => None, // "keep" = no transformation @@ -259,8 +264,8 @@ mod tests { let sql = sql.unwrap(); // Should generate CASE WHEN for censor assert!(sql.contains("CASE WHEN")); - assert!(sql.contains("value >= 0")); - assert!(sql.contains("value <= 100")); + assert!(sql.contains("\"value\" >= 0")); + 
assert!(sql.contains("\"value\" <= 100")); assert!(sql.contains("ELSE NULL")); } diff --git a/src/plot/scale/scale_type/discrete.rs b/src/plot/scale/scale_type/discrete.rs index ca31edfa..342daf9d 100644 --- a/src/plot/scale/scale_type/discrete.rs +++ b/src/plot/scale/scale_type/discrete.rs @@ -259,11 +259,12 @@ impl ScaleTypeTrait for Discrete { } // Always censor - discrete scales have no other valid OOB behavior + let quoted = format!("\"{}\"", column_name); Some(format!( "(CASE WHEN {} IN ({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/plot/scale/scale_type/mod.rs b/src/plot/scale/scale_type/mod.rs index baac8f4b..d54dfdae 100644 --- a/src/plot/scale/scale_type/mod.rs +++ b/src/plot/scale/scale_type/mod.rs @@ -3394,7 +3394,7 @@ mod tests { let dialect = AnsiDialect; assert_eq!( dialect.type_name_for(CastTargetType::Number), - Some("DOUBLE") + Some("DOUBLE PRECISION") ); assert_eq!( dialect.type_name_for(CastTargetType::Integer), diff --git a/src/plot/scale/scale_type/ordinal.rs b/src/plot/scale/scale_type/ordinal.rs index bc3c0d5e..50d1f472 100644 --- a/src/plot/scale/scale_type/ordinal.rs +++ b/src/plot/scale/scale_type/ordinal.rs @@ -291,11 +291,12 @@ impl ScaleTypeTrait for Ordinal { } // Always censor - ordinal scales have no other valid OOB behavior + let quoted = format!("\"{}\"", column_name); Some(format!( "(CASE WHEN {} IN ({}) THEN {} ELSE NULL END)", - column_name, + quoted, allowed_values.join(", "), - column_name + quoted )) } } diff --git a/src/reader/connection.rs b/src/reader/connection.rs index 63f90cf7..b97bd553 100644 --- a/src/reader/connection.rs +++ b/src/reader/connection.rs @@ -17,6 +17,9 @@ pub enum ConnectionInfo { /// SQLite file-based database #[allow(dead_code)] SQLite(String), + /// Generic ODBC connection (raw connection string after `odbc://` prefix) + #[allow(dead_code)] + ODBC(String), } /// Parse a connection string into connection information @@ 
-70,8 +73,17 @@ pub fn parse_connection_string(uri: &str) -> Result { return Ok(ConnectionInfo::SQLite(cleaned_path.to_string())); } + if let Some(conn_str) = uri.strip_prefix("odbc://") { + if conn_str.is_empty() { + return Err(GgsqlError::ReaderError( + "ODBC connection string cannot be empty".to_string(), + )); + } + return Ok(ConnectionInfo::ODBC(conn_str.to_string())); + } + Err(GgsqlError::ReaderError(format!( - "Unsupported connection string format: {}. Supported: duckdb://, postgres://, sqlite://", + "Unsupported connection string format: {}. Supported: duckdb://, postgres://, sqlite://, odbc://", uri ))) } @@ -133,6 +145,26 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_odbc() { + let info = parse_connection_string( + "odbc://Driver=Snowflake;Server=myaccount.snowflakecomputing.com", + ) + .unwrap(); + assert_eq!( + info, + ConnectionInfo::ODBC( + "Driver=Snowflake;Server=myaccount.snowflakecomputing.com".to_string() + ) + ); + } + + #[test] + fn test_odbc_empty() { + let result = parse_connection_string("odbc://"); + assert!(result.is_err()); + } + #[test] fn test_unsupported_scheme() { let result = parse_connection_string("mysql://localhost/db"); diff --git a/src/reader/data.rs b/src/reader/data.rs index 720ccac7..6e83b11a 100644 --- a/src/reader/data.rs +++ b/src/reader/data.rs @@ -185,7 +185,7 @@ pub fn rewrite_namespaced_sql(sql: &str) -> Result { replacements.push(( node.start_byte(), node.end_byte(), - naming::builtin_data_table(name), + format!("\"{}\"", naming::builtin_data_table(name)), )); } } @@ -315,7 +315,7 @@ mod tests { fn test_rewrite_namespaced_sql_simple() { let sql = "SELECT * FROM ggsql:penguins"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert_eq!(rewritten, "SELECT * FROM __ggsql_data_penguins__"); + assert_eq!(rewritten, "SELECT * FROM \"__ggsql_data_penguins__\""); } #[test] @@ -324,7 +324,7 @@ mod tests { let rewritten = rewrite_namespaced_sql(sql).unwrap(); assert_eq!( rewritten, - "SELECT * FROM 
__ggsql_data_penguins__ p, __ggsql_data_airquality__ a WHERE p.id = a.id" + "SELECT * FROM \"__ggsql_data_penguins__\" p, \"__ggsql_data_airquality__\" a WHERE p.id = a.id" ); } @@ -339,7 +339,7 @@ mod tests { fn test_rewrite_namespaced_sql_with_visualise() { let sql = "SELECT * FROM ggsql:penguins VISUALISE DRAW point MAPPING bill_len AS x, bill_dep AS y"; let rewritten = rewrite_namespaced_sql(sql).unwrap(); - assert!(rewritten.starts_with("SELECT * FROM __ggsql_data_penguins__")); + assert!(rewritten.starts_with("SELECT * FROM \"__ggsql_data_penguins__\"")); assert!(!rewritten.contains("ggsql:")); } } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index c35dc05f..beb8e7f1 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -36,7 +36,7 @@ impl super::SqlDialect for DuckDbDialect { fn sql_generate_series(&self, n: usize) -> String { format!( - "__ggsql_seq__(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", + "\"__ggsql_seq__\"(n) AS (SELECT generate_series FROM GENERATE_SERIES(0, {}))", n - 1 ) } @@ -44,14 +44,19 @@ impl super::SqlDialect for DuckDbDialect { fn sql_percentile(&self, column: &str, fraction: f64, from: &str, groups: &[String]) -> String { let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| { + let q = crate::naming::quote_ident(g); + format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") + }) .collect::>() .join(" "); + let quoted_column = crate::naming::quote_ident(column); format!( "(SELECT QUANTILE_CONT({column}, {fraction}) \ - FROM ({from}) AS __ggsql_pct__ \ - WHERE {column} IS NOT NULL {group_filter})" + FROM ({from}) AS \"__ggsql_pct__\" \ + WHERE {column} IS NOT NULL {group_filter})", + column = quoted_column ) } } @@ -144,34 +149,7 @@ impl DuckDBReader { } } -/// Validate a table name -fn validate_table_name(name: &str) -> Result<()> { - if name.is_empty() { - return Err(GgsqlError::ReaderError("Table name 
cannot be empty".into())); - } - - // Reject characters that could break double-quoted identifiers or cause issues - let forbidden = ['"', '\0', '\n', '\r']; - for ch in forbidden { - if name.contains(ch) { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' contains invalid character '{}'", - name, - ch.escape_default() - ))); - } - } - - // Reasonable length limit - if name.len() > 128 { - return Err(GgsqlError::ReaderError(format!( - "Table name '{}' exceeds maximum length of 128 characters", - name - ))); - } - - Ok(()) -} +use super::validate_table_name; /// Convert a Polars DataFrame to DuckDB Arrow query parameters via IPC serialization fn dataframe_to_arrow_params(df: DataFrame) -> Result<[usize; 2]> { @@ -639,6 +617,10 @@ impl Reader for DuckDBReader { Ok(()) } + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn super::SqlDialect { &DuckDbDialect } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index cc27b392..c9b03464 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -46,9 +46,9 @@ use crate::{DataFrame, GgsqlError, Result}; /// /// Default implementations produce portable ANSI SQL. pub trait SqlDialect { - /// SQL type name for numeric columns (e.g., "DOUBLE") + /// SQL type name for numeric columns (e.g., "DOUBLE PRECISION") fn number_type_name(&self) -> Option<&str> { - Some("DOUBLE") + Some("DOUBLE PRECISION") } /// SQL type name for integer columns (e.g., "BIGINT") @@ -94,6 +94,48 @@ pub trait SqlDialect { } } + // ========================================================================= + // Schema introspection queries (for Connections pane) + // ========================================================================= + + /// SQL to list catalog names. Returns rows with column `catalog_name`. 
+ fn sql_list_catalogs(&self) -> String { + "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY catalog_name".into() + } + + /// SQL to list schema names within a catalog. Returns rows with column `schema_name`. + fn sql_list_schemas(&self, catalog: &str) -> String { + format!( + "SELECT DISTINCT schema_name FROM information_schema.schemata \ + WHERE catalog_name = '{}' ORDER BY schema_name", + catalog.replace('\'', "''") + ) + } + + /// SQL to list tables/views within a catalog and schema. + /// Returns rows with columns `table_name` and `table_type`. + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + format!( + "SELECT DISTINCT table_name, table_type FROM information_schema.tables \ + WHERE table_catalog = '{}' AND table_schema = '{}' ORDER BY table_name", + catalog.replace('\'', "''"), + schema.replace('\'', "''") + ) + } + + /// SQL to list columns in a table. + /// Returns rows with columns `column_name` and `data_type`. + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + format!( + "SELECT column_name, data_type FROM information_schema.columns \ + WHERE table_catalog = '{}' AND table_schema = '{}' AND table_name = '{}' \ + ORDER BY ordinal_position", + catalog.replace('\'', "''"), + schema.replace('\'', "''"), + table.replace('\'', "''") + ) + } + /// Scalar MAX across any number of SQL expressions. 
fn sql_greatest(&self, exprs: &[&str]) -> String { let mut result = exprs[0].to_string(); @@ -124,12 +166,12 @@ pub trait SqlDialect { let base_sq = base_size * base_size; let base_max = base_size - 1; format!( - "__ggsql_base__(n) AS (\ - SELECT 0 UNION ALL SELECT n + 1 FROM __ggsql_base__ WHERE n < {base_max}\ + "\"__ggsql_base__\"(n) AS (\ + SELECT 0 UNION ALL SELECT n + 1 FROM \"__ggsql_base__\" WHERE n < {base_max}\ ),\ - __ggsql_seq__(n) AS (\ + \"__ggsql_seq__\"(n) AS (\ SELECT CAST(a.n * {base_sq} + b.n * {base_size} + c.n AS REAL) AS n \ - FROM __ggsql_base__ a, __ggsql_base__ b, __ggsql_base__ c \ + FROM \"__ggsql_base__\" a, \"__ggsql_base__\" b, \"__ggsql_base__\" c \ WHERE a.n * {base_sq} + b.n * {base_size} + c.n < {n}\ )" ) @@ -143,12 +185,16 @@ pub trait SqlDialect { // Uses NTILE(4) to divide data into quartiles, then interpolates between boundaries. let group_filter = groups .iter() - .map(|g| format!("AND __ggsql_pct__.{g} IS NOT DISTINCT FROM __ggsql_qt__.{g}")) + .map(|g| { + let q = crate::naming::quote_ident(g); + format!("AND \"__ggsql_pct__\".{q} IS NOT DISTINCT FROM \"__ggsql_qt__\".{q}") + }) .collect::>() .join(" "); let lo_tile = (fraction * 4.0).ceil() as usize; let hi_tile = lo_tile + 1; + let quoted_column = crate::naming::quote_ident(column); format!( "(SELECT (\ @@ -158,9 +204,10 @@ pub trait SqlDialect { FROM (\ SELECT {column} AS __val, \ NTILE(4) OVER (ORDER BY {column}) AS __tile \ - FROM ({from}) AS __ggsql_pct__ \ + FROM ({from}) AS \"__ggsql_pct__\" \ WHERE {column} IS NOT NULL {group_filter}\ - ))" + ))", + column = quoted_column ) } @@ -209,6 +256,12 @@ pub mod duckdb; #[cfg(feature = "sqlite")] pub mod sqlite; +#[cfg(feature = "odbc")] +pub mod odbc; + +#[cfg(feature = "odbc")] +pub mod snowflake; + pub mod connection; pub mod data; mod spec; @@ -219,6 +272,45 @@ pub use duckdb::DuckDBReader; #[cfg(feature = "sqlite")] pub use sqlite::SqliteReader; +#[cfg(feature = "odbc")] +pub use odbc::OdbcReader; + +// 
============================================================================ +// Shared utilities +// ============================================================================ + +/// Validate a table name for use in SQL statements. +/// +/// Rejects empty names, names with characters that could break double-quoted +/// identifiers, and names exceeding 128 characters. +pub(crate) fn validate_table_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); + } + + // Reject characters that could break double-quoted identifiers or cause issues + let forbidden = ['"', '\0', '\n', '\r']; + for ch in forbidden { + if name.contains(ch) { + return Err(GgsqlError::ReaderError(format!( + "Table name '{}' contains invalid character '{}'", + name, + ch.escape_default() + ))); + } + } + + // Reasonable length limit + if name.len() > 128 { + return Err(GgsqlError::ReaderError(format!( + "Table name '{}' exceeds maximum length of 128 characters", + name + ))); + } + + Ok(()) +} + // ============================================================================ // Spec - Result of reader.execute() // ============================================================================ @@ -363,37 +455,7 @@ pub trait Reader { /// let writer = VegaLiteWriter::new(); /// let json = writer.render(&spec)?; /// ``` - fn execute(&self, query: &str) -> Result - where - Self: Sized, - { - // Run validation first to capture warnings - let validated = validate(query)?; - let warnings: Vec = validated.warnings().to_vec(); - - // Prepare data with type names for this reader - let prepared_data = prepare_data_with_reader(query, self)?; - - // Get the first (and typically only) spec - let plot = prepared_data.specs.into_iter().next().ok_or_else(|| { - GgsqlError::ValidationError("No visualization spec found".to_string()) - })?; - - // For now, layer_sql and stat_sql are not tracked in PreparedData - // (they were part of main's 
version but not HEAD's) - let layer_sql = vec![None; plot.layers.len()]; - let stat_sql = vec![None; plot.layers.len()]; - - Ok(Spec::new( - plot, - prepared_data.data, - prepared_data.sql, - prepared_data.visual, - layer_sql, - stat_sql, - warnings, - )) - } + fn execute(&self, query: &str) -> Result; /// Get the SQL dialect for this reader. /// @@ -403,6 +465,36 @@ pub trait Reader { } } +/// Execute a ggsql query using any reader (object-safe entry point). +/// +/// This is the shared implementation behind `Reader::execute()`. Concrete +/// readers delegate to this so the trait stays object-safe (no `Self: Sized` +/// bound on `execute`). +pub fn execute_with_reader(reader: &dyn Reader, query: &str) -> Result { + let validated = validate(query)?; + let warnings: Vec = validated.warnings().to_vec(); + + let prepared_data = prepare_data_with_reader(query, reader)?; + + let plot = + prepared_data.specs.into_iter().next().ok_or_else(|| { + GgsqlError::ValidationError("No visualization spec found".to_string()) + })?; + + let layer_sql = vec![None; plot.layers.len()]; + let stat_sql = vec![None; plot.layers.len()]; + + Ok(Spec::new( + plot, + prepared_data.data, + prepared_data.sql, + prepared_data.visual, + layer_sql, + stat_sql, + warnings, + )) +} + #[cfg(test)] #[cfg(all(feature = "duckdb", feature = "vegalite"))] mod tests { diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs new file mode 100644 index 00000000..d5b6339d --- /dev/null +++ b/src/reader/odbc.rs @@ -0,0 +1,678 @@ +//! Generic ODBC data source implementation +//! +//! Provides a reader for any ODBC-compatible database (Snowflake, PostgreSQL, +//! SQL Server, etc.) using the `odbc-api` crate. 
+ +use crate::reader::Reader; +use crate::{DataFrame, GgsqlError, Result}; +use odbc_api::{buffers::TextRowSet, ConnectionOptions, Cursor, Environment}; +use polars::prelude::*; +use std::cell::RefCell; +use std::collections::HashSet; +use std::sync::OnceLock; + +/// Global ODBC environment (must be a singleton per process). +fn odbc_env() -> &'static Environment { + static ENV: OnceLock = OnceLock::new(); + ENV.get_or_init(|| Environment::new().expect("Failed to create ODBC environment")) +} + +/// Detect the backend SQL dialect from an ODBC connection string. +/// +/// Returns a dialect matching the detected backend (e.g. Snowflake, SQLite, +/// DuckDB, or ANSI for generic/unknown backends). +fn detect_dialect(conn_str: &str) -> Box { + let lower = conn_str.to_lowercase(); + if lower.contains("driver=snowflake") { + Box::new(super::snowflake::SnowflakeDialect) + } else if lower.contains("driver=sqlite") || lower.contains("driver={sqlite") { + #[cfg(feature = "sqlite")] + { + Box::new(super::sqlite::SqliteDialect) + } + #[cfg(not(feature = "sqlite"))] + { + Box::new(super::AnsiDialect) + } + } else if lower.contains("driver=duckdb") || lower.contains("driver={duckdb") { + #[cfg(feature = "duckdb")] + { + Box::new(super::duckdb::DuckDbDialect) + } + #[cfg(not(feature = "duckdb"))] + { + Box::new(super::AnsiDialect) + } + } else { + Box::new(super::AnsiDialect) + } +} + +/// Generic ODBC reader implementing the `Reader` trait. +pub struct OdbcReader { + connection: odbc_api::Connection<'static>, + dialect: Box, + registered_tables: RefCell>, +} + +// Safety: odbc_api::Connection is Send when we ensure single-threaded access. +// The Reader trait requires &self (immutable) for execute_sql, and ODBC +// connections are safe to use from one thread at a time. +unsafe impl Send for OdbcReader {} + +impl OdbcReader { + /// Create a new ODBC reader from a `odbc://` connection URI. + /// + /// The URI format is `odbc://` followed by the raw ODBC connection string. 
+ /// For Snowflake with Posit Workbench credentials, the reader will + /// automatically detect and inject OAuth tokens. + pub fn from_connection_string(uri: &str) -> Result { + let conn_str = uri + .strip_prefix("odbc://") + .ok_or_else(|| GgsqlError::ReaderError("ODBC URI must start with odbc://".into()))?; + + let mut conn_str = conn_str.to_string(); + + // Snowflake ConnectionName resolution from connections.toml + if is_snowflake(&conn_str) { + if let Some(resolved) = resolve_connection_name(&conn_str) { + conn_str = resolved; + } + } + + // Snowflake Workbench credential detection + if is_snowflake(&conn_str) && !has_token(&conn_str) { + if let Some(token) = detect_workbench_token() { + conn_str = inject_snowflake_token(&conn_str, &token); + } + } + + // Detect backend dialect from connection string + let dialect = detect_dialect(&conn_str); + + let env = odbc_env(); + let connection = env + .connect_with_connection_string(&conn_str, ConnectionOptions::default()) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC connection failed: {}", e)))?; + + Ok(Self { + connection, + dialect, + registered_tables: RefCell::new(HashSet::new()), + }) + } +} + +impl Reader for OdbcReader { + fn execute_sql(&self, sql: &str) -> Result { + // Execute the query (3rd arg = query timeout, None = no timeout) + let cursor = self + .connection + .execute(sql, (), None) + .map_err(|e| GgsqlError::ReaderError(format!("ODBC execute failed: {}", e)))?; + + let Some(cursor) = cursor else { + // DDL or non-query statement — return empty DataFrame + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(format!("Empty DataFrame error: {}", e))); + }; + + cursor_to_dataframe(cursor) + } + + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> { + super::validate_table_name(name)?; + + if replace { + let drop_sql = format!("DROP TABLE IF EXISTS \"{}\"", name); + // Ignore errors from DROP — table may not exist + let _ = 
self.connection.execute(&drop_sql, (), None); + } + + // Build CREATE TEMP TABLE with typed columns + let schema = df.schema(); + let col_defs: Vec = schema + .iter() + .map(|(col_name, dtype)| format!("\"{}\" {}", col_name, polars_dtype_to_sql(dtype))) + .collect(); + let create_sql = format!( + "CREATE TEMPORARY TABLE \"{}\" ({})", + name, + col_defs.join(", ") + ); + self.connection + .execute(&create_sql, (), None) + .map_err(|e| { + GgsqlError::ReaderError(format!("Failed to create temp table '{}': {}", name, e)) + })?; + + // Insert data using ODBC bulk text inserter + let num_rows = df.height(); + if num_rows > 0 { + let num_cols = df.width(); + let placeholders: Vec<&str> = vec!["?"; num_cols]; + let insert_sql = format!( + "INSERT INTO \"{}\" VALUES ({})", + name, + placeholders.join(", ") + ); + + // Convert all columns to string representation for text insertion + let string_columns: Vec>> = df + .get_columns() + .iter() + .map(|col| { + (0..num_rows) + .map(|row| { + let val = col.get(row).ok()?; + if val == AnyValue::Null { + None + } else { + Some(format!("{}", val)) + } + }) + .collect() + }) + .collect(); + + // Determine max string length per column for buffer allocation + let max_str_lens: Vec = string_columns + .iter() + .map(|col| { + col.iter() + .filter_map(|v| v.as_ref().map(|s| s.len())) + .max() + .unwrap_or(1) + .max(1) // minimum buffer size of 1 + }) + .collect(); + + const BATCH_SIZE: usize = 1024; + let prepared = self.connection.prepare(&insert_sql).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to prepare INSERT for '{}': {}", name, e)) + })?; + + let batch_capacity = num_rows.min(BATCH_SIZE); + let mut inserter = prepared + .into_text_inserter(batch_capacity, max_str_lens) + .map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to create bulk inserter for '{}': {}", + name, e + )) + })?; + + let mut rows_in_batch = 0; + for row_idx in 0..num_rows { + let row_values: Vec> = string_columns + .iter() + .map(|col| 
col[row_idx].as_ref().map(|s| s.as_bytes())) + .collect(); + + inserter.append(row_values.into_iter()).map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to append row {} to '{}': {}", + row_idx, name, e + )) + })?; + rows_in_batch += 1; + + if rows_in_batch >= BATCH_SIZE { + inserter.execute().map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to execute batch insert into '{}': {}", + name, e + )) + })?; + inserter.clear(); + rows_in_batch = 0; + } + } + + // Execute final partial batch + if rows_in_batch > 0 { + inserter.execute().map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to execute final batch insert into '{}': {}", + name, e + )) + })?; + } + } + + self.registered_tables.borrow_mut().insert(name.to_string()); + Ok(()) + } + + fn unregister(&self, name: &str) -> Result<()> { + if !self.registered_tables.borrow().contains(name) { + return Err(GgsqlError::ReaderError(format!( + "Table '{}' was not registered via this reader", + name + ))); + } + + let sql = format!("DROP TABLE IF EXISTS \"{}\"", name); + self.connection.execute(&sql, (), None).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to unregister table '{}': {}", name, e)) + })?; + + self.registered_tables.borrow_mut().remove(name); + Ok(()) + } + + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + + fn dialect(&self) -> &dyn super::SqlDialect { + &*self.dialect + } +} + +/// Map a Polars data type to a SQL column type string. 
+fn polars_dtype_to_sql(dtype: &DataType) -> &'static str { + match dtype { + DataType::Boolean => "BOOLEAN", + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT", + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT", + DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION", + DataType::Date => "DATE", + DataType::Datetime(_, _) => "TIMESTAMP", + DataType::Time => "TIME", + _ => "TEXT", + } +} + +/// Convert an ODBC cursor to a Polars DataFrame by fetching all rows as text. +fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result { + let col_count = cursor + .num_result_cols() + .map_err(|e| GgsqlError::ReaderError(format!("Failed to get column count: {}", e)))? + as usize; + + if col_count == 0 { + return DataFrame::new(Vec::::new()) + .map_err(|e| GgsqlError::ReaderError(e.to_string())); + } + + // Collect column names + let mut col_names = Vec::with_capacity(col_count); + for i in 1..=col_count as u16 { + let name = cursor.col_name(i).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to get column {} name: {}", i, e)) + })?; + col_names.push(name); + } + + // Fetch all rows as text into column-oriented vectors + let batch_size = 1000; + let max_str_len = 4096; + let mut columns: Vec>> = vec![Vec::new(); col_count]; + + let mut row_set = TextRowSet::for_cursor(batch_size, &mut cursor, Some(max_str_len)) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to create row set: {}", e)))?; + + let mut block_cursor = cursor + .bind_buffer(&mut row_set) + .map_err(|e| GgsqlError::ReaderError(format!("Failed to bind buffer: {}", e)))?; + + while let Some(batch) = block_cursor + .fetch() + .map_err(|e| GgsqlError::ReaderError(format!("Failed to fetch batch: {}", e)))? 
+ { + let num_rows = batch.num_rows(); + for (col_idx, column) in columns.iter_mut().enumerate() { + for row_idx in 0..num_rows { + let value = batch + .at_as_str(col_idx, row_idx) + .ok() + .flatten() + .map(|s| s.to_string()); + column.push(value); + } + } + } + + // Build Polars Series from the text data, attempting type inference + let series: Vec = col_names + .iter() + .zip(columns.iter()) + .map(|(name, values)| { + // Try to parse as numeric first, then fall back to string + let series = if let Some(int_series) = try_parse_integers(name, values) { + int_series + } else if let Some(float_series) = try_parse_floats(name, values) { + float_series + } else { + // Fall back to string + Series::new( + name.into(), + values + .iter() + .map(|v| v.as_deref()) + .collect::>>(), + ) + }; + Column::from(series) + }) + .collect(); + + DataFrame::new(series).map_err(|e| GgsqlError::ReaderError(e.to_string())) +} + +/// Try to parse all non-null values as i64. +fn try_parse_integers(name: &str, values: &[Option]) -> Option { + let parsed: Vec> = values + .iter() + .map(|v| match v { + None => Some(None), + Some(s) => s.parse::().ok().map(Some), + }) + .collect::>>()?; + Some(Series::new(name.into(), parsed)) +} + +/// Try to parse all non-null values as f64. 
+fn try_parse_floats(name: &str, values: &[Option]) -> Option { + let parsed: Vec> = values + .iter() + .map(|v| match v { + None => Some(None), + Some(s) => s.parse::().ok().map(Some), + }) + .collect::>>()?; + Some(Series::new(name.into(), parsed)) +} + +// ============================================================================ +// Snowflake Workbench credential detection +// ============================================================================ + +fn is_snowflake(conn_str: &str) -> bool { + conn_str.to_lowercase().contains("driver=snowflake") +} + +fn has_token(conn_str: &str) -> bool { + conn_str.to_lowercase().contains("token=") +} + +fn home_dir() -> Option { + #[cfg(target_os = "windows")] + { + std::env::var("USERPROFILE") + .ok() + .map(std::path::PathBuf::from) + } + #[cfg(not(target_os = "windows"))] + { + std::env::var("HOME").ok().map(std::path::PathBuf::from) + } +} + +/// Find the Snowflake connections.toml file, checking standard locations. +fn find_snowflake_connections_toml() -> Option { + use std::path::PathBuf; + + // 1. $SNOWFLAKE_HOME/connections.toml + if let Ok(snowflake_home) = std::env::var("SNOWFLAKE_HOME") { + let p = PathBuf::from(&snowflake_home).join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + // 2. ~/.snowflake/connections.toml + if let Some(home) = home_dir() { + let p = home.join(".snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + // 3. 
Platform-specific paths + if let Some(home) = home_dir() { + #[cfg(target_os = "macos")] + { + let p = home.join("Library/Application Support/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "linux")] + { + let xdg = std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| home.join(".config")); + let p = xdg.join("snowflake").join("connections.toml"); + if p.exists() { + return Some(p); + } + } + + #[cfg(target_os = "windows")] + { + let p = home.join("AppData/Local/snowflake/connections.toml"); + if p.exists() { + return Some(p); + } + } + } + + None +} + +/// Resolve a `ConnectionName=` parameter in a Snowflake ODBC connection +/// string by reading the named entry from `~/.snowflake/connections.toml` and +/// building a full ODBC connection string from it. +fn resolve_connection_name(conn_str: &str) -> Option { + // Extract ConnectionName value (case-insensitive) + let lower = conn_str.to_lowercase(); + let cn_key = "connectionname="; + let cn_start = lower.find(cn_key)?; + let value_start = cn_start + cn_key.len(); + + let rest = &conn_str[value_start..]; + let value_end = rest.find(';').unwrap_or(rest.len()); + let connection_name = rest[..value_end].trim(); + + if connection_name.is_empty() { + return None; + } + + // Read and parse connections.toml + let toml_path = find_snowflake_connections_toml()?; + let content = std::fs::read_to_string(&toml_path).ok()?; + let doc = content.parse::().ok()?; + + let entry = doc.get(connection_name)?; + if !entry.is_table() && !entry.is_inline_table() { + return None; + } + + // Build ODBC connection string from TOML entry fields + let get_str = |key: &str| -> Option { entry.get(key)?.as_str().map(|s| s.to_string()) }; + + let account = get_str("account")?; + let mut parts = vec![ + "Driver=Snowflake".to_string(), + format!("Server={}.snowflakecomputing.com", account), + ]; + + if let Some(user) = get_str("user") { + parts.push(format!("UID={}", user)); + 
} + if let Some(password) = get_str("password") { + parts.push(format!("PWD={}", password)); + } + if let Some(authenticator) = get_str("authenticator") { + parts.push(format!("Authenticator={}", authenticator)); + } + if let Some(token) = get_str("token") { + parts.push(format!("Token={}", token)); + } + if let Some(warehouse) = get_str("warehouse") { + parts.push(format!("Warehouse={}", warehouse)); + } + if let Some(database) = get_str("database") { + parts.push(format!("Database={}", database)); + } + if let Some(schema) = get_str("schema") { + parts.push(format!("Schema={}", schema)); + } + if let Some(role) = get_str("role") { + parts.push(format!("Role={}", role)); + } + + Some(parts.join(";")) +} + +/// Detect Posit Workbench Snowflake OAuth token. +/// +/// Checks `SNOWFLAKE_HOME` for a Workbench-managed `connections.toml` file +/// containing OAuth credentials. +fn detect_workbench_token() -> Option { + let snowflake_home = std::env::var("SNOWFLAKE_HOME").ok()?; + + // Only use Workbench credentials if the path indicates Workbench management + if !snowflake_home.contains("posit-workbench") { + return None; + } + + let toml_path = std::path::Path::new(&snowflake_home).join("connections.toml"); + let content = std::fs::read_to_string(&toml_path).ok()?; + + let doc = content.parse::().ok()?; + let token = doc.get("workbench")?.get("token")?.as_str()?.to_string(); + + if token.is_empty() { + None + } else { + Some(token) + } +} + +/// Inject OAuth token into a Snowflake ODBC connection string. 
+fn inject_snowflake_token(conn_str: &str, token: &str) -> String { + // Append authenticator and token parameters + let mut result = conn_str.trim_end_matches(';').to_string(); + result.push_str(";Authenticator=oauth;Token="); + result.push_str(token); + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_snowflake() { + assert!(is_snowflake( + "Driver=Snowflake;Server=foo.snowflakecomputing.com" + )); + assert!(!is_snowflake("Driver={PostgreSQL};Server=localhost")); + } + + #[test] + fn test_has_token() { + assert!(has_token("Driver=Snowflake;Token=abc123")); + assert!(!has_token("Driver=Snowflake;Server=foo")); + } + + #[test] + fn test_detect_dialect() { + // Snowflake uses SHOW commands + let dialect = detect_dialect("Driver=Snowflake;Server=foo"); + assert!(dialect.sql_list_catalogs().contains("SHOW")); + + // PostgreSQL uses information_schema (ANSI default) + let dialect = detect_dialect("Driver={PostgreSQL};Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); + + // Generic uses information_schema (ANSI default) + let dialect = detect_dialect("Driver=SomeOther;Server=localhost"); + assert!(dialect.sql_list_catalogs().contains("information_schema")); + } + + #[test] + fn test_inject_snowflake_token() { + let result = inject_snowflake_token( + "Driver=Snowflake;Server=foo.snowflakecomputing.com", + "mytoken", + ); + assert!(result.contains("Authenticator=oauth")); + assert!(result.contains("Token=mytoken")); + } + + #[test] + fn test_resolve_connection_name_with_toml() { + use std::io::Write; + + // Create a temp dir with a connections.toml + let dir = tempfile::tempdir().unwrap(); + let toml_path = dir.path().join("connections.toml"); + let mut f = std::fs::File::create(&toml_path).unwrap(); + writeln!( + f, + r#" +default_connection_name = "myconn" + +[myconn] +account = "myaccount" +user = "myuser" +password = "mypass" +warehouse = "mywh" +database = "mydb" +schema = "public" +role = 
"myrole" + +[other] +account = "otheraccount" +"# + ) + .unwrap(); + + // Point SNOWFLAKE_HOME at our temp dir + std::env::set_var("SNOWFLAKE_HOME", dir.path()); + + let result = resolve_connection_name("Driver=Snowflake;ConnectionName=myconn"); + assert!(result.is_some()); + let conn = result.unwrap(); + assert!(conn.contains("Driver=Snowflake")); + assert!(conn.contains("Server=myaccount.snowflakecomputing.com")); + assert!(conn.contains("UID=myuser")); + assert!(conn.contains("PWD=mypass")); + assert!(conn.contains("Warehouse=mywh")); + assert!(conn.contains("Database=mydb")); + assert!(conn.contains("Schema=public")); + assert!(conn.contains("Role=myrole")); + + // Test with a connection that has fewer fields + let result2 = resolve_connection_name("Driver=Snowflake;ConnectionName=other"); + assert!(result2.is_some()); + let conn2 = result2.unwrap(); + assert!(conn2.contains("Server=otheraccount.snowflakecomputing.com")); + assert!(!conn2.contains("UID=")); + + // Test with non-existent connection name + let result3 = resolve_connection_name("Driver=Snowflake;ConnectionName=nonexistent"); + assert!(result3.is_none()); + + // No ConnectionName param → None + let result4 = resolve_connection_name("Driver=Snowflake;Server=foo"); + assert!(result4.is_none()); + + // Clean up env + std::env::remove_var("SNOWFLAKE_HOME"); + } + + #[test] + fn test_polars_dtype_to_sql() { + assert_eq!(polars_dtype_to_sql(&DataType::Int64), "BIGINT"); + assert_eq!(polars_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION"); + assert_eq!(polars_dtype_to_sql(&DataType::Boolean), "BOOLEAN"); + assert_eq!(polars_dtype_to_sql(&DataType::Date), "DATE"); + assert_eq!(polars_dtype_to_sql(&DataType::String), "TEXT"); + } +} diff --git a/src/reader/snowflake.rs b/src/reader/snowflake.rs new file mode 100644 index 00000000..33d5dd9e --- /dev/null +++ b/src/reader/snowflake.rs @@ -0,0 +1,30 @@ +//! Snowflake-specific SQL dialect. +//! +//! 
Overrides schema introspection to use Snowflake's SHOW commands +//! instead of information_schema queries. + +pub struct SnowflakeDialect; + +impl super::SqlDialect for SnowflakeDialect { + fn sql_list_catalogs(&self) -> String { + "SHOW DATABASES".into() + } + + fn sql_list_schemas(&self, catalog: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + format!("SHOW SCHEMAS IN DATABASE \"{catalog_ident}\"") + } + + fn sql_list_tables(&self, catalog: &str, schema: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); + format!("SHOW OBJECTS IN SCHEMA \"{catalog_ident}\".\"{schema_ident}\"") + } + + fn sql_list_columns(&self, catalog: &str, schema: &str, table: &str) -> String { + let catalog_ident = catalog.replace('"', "\"\""); + let schema_ident = schema.replace('"', "\"\""); + let table_ident = table.replace('"', "\"\""); + format!("SHOW COLUMNS IN TABLE \"{catalog_ident}\".\"{schema_ident}\".\"{table_ident}\"") + } +} diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs index 793ec928..b78e6a98 100644 --- a/src/reader/sqlite.rs +++ b/src/reader/sqlite.rs @@ -67,6 +67,29 @@ impl super::SqlDialect for SqliteDialect { "0".to_string() } } + + fn sql_list_catalogs(&self) -> String { + "SELECT name AS catalog_name FROM pragma_database_list ORDER BY name".into() + } + + fn sql_list_schemas(&self, _catalog: &str) -> String { + "SELECT 'main' AS schema_name".into() + } + + fn sql_list_tables(&self, catalog: &str, _schema: &str) -> String { + format!( + "SELECT name AS table_name, type AS table_type FROM \"{}\".sqlite_master \ + WHERE type IN ('table', 'view') ORDER BY name", + catalog.replace('"', "\"\"") + ) + } + + fn sql_list_columns(&self, _catalog: &str, _schema: &str, table: &str) -> String { + format!( + "SELECT name AS column_name, type AS data_type FROM pragma_table_info('{}') ORDER BY cid", + table.replace('\'', "''") + ) + } } /// SQLite database reader @@ -442,6 +465,10 @@ 
impl Reader for SqliteReader { Ok(()) } + fn execute(&self, query: &str) -> Result { + super::execute_with_reader(self, query) + } + fn dialect(&self) -> &dyn super::SqlDialect { &SqliteDialect }