Skip to content

Commit b3d6b49

Browse files
committed
Add SQL dialect function overriding for duckdb
1 parent 1f2795c commit b3d6b49

23 files changed

Lines changed: 547 additions & 529 deletions

File tree

ggsql-python/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct PyReaderBridge {
127127
obj: Py<PyAny>,
128128
}
129129

130+
static ANSI_DIALECT: ggsql::reader::AnsiDialect = ggsql::reader::AnsiDialect;
131+
130132
impl Reader for PyReaderBridge {
131133
fn execute_sql(&self, sql: &str) -> ggsql::Result<DataFrame> {
132134
Python::attach(|py| {
@@ -161,6 +163,10 @@ impl Reader for PyReaderBridge {
161163
Ok(())
162164
})
163165
}
166+
167+
fn dialect(&self) -> &dyn ggsql::reader::SqlDialect {
168+
&ANSI_DIALECT
169+
}
164170
}
165171

166172
// ============================================================================

src/execute/casting.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
//! scale requirements and updating type info accordingly.
55
66
use crate::plot::scale::coerce_dtypes;
7-
use crate::plot::{CastTargetType, Layer, ParameterValue, Plot, SqlTypeNames};
7+
use crate::plot::{CastTargetType, Layer, ParameterValue, Plot};
8+
use crate::reader::SqlDialect;
89
use crate::{naming, DataSource};
910
use polars::prelude::{DataType, TimeUnit};
1011
use std::collections::{HashMap, HashSet};
@@ -57,7 +58,7 @@ pub fn literal_to_sql(lit: &ParameterValue) -> String {
5758
pub fn determine_type_requirements(
5859
spec: &Plot,
5960
layer_type_info: &[Vec<TypeInfo>],
60-
type_names: &SqlTypeNames,
61+
dialect: &dyn SqlDialect,
6162
) -> Vec<Vec<TypeRequirement>> {
6263
use crate::plot::scale::TransformKind;
6364

@@ -123,7 +124,7 @@ pub fn determine_type_requirements(
123124

124125
// Check if this specific column needs casting
125126
if let Some(cast_target) = scale_type.required_cast_type(col_dtype, &target_dtype) {
126-
if let Some(sql_type) = type_names.for_target(cast_target) {
127+
if let Some(sql_type) = dialect.type_name_for(cast_target) {
127128
// Don't add duplicate requirements for same column
128129
if !requirements.iter().any(|r| r.column == col_name) {
129130
requirements.push(TypeRequirement {
@@ -155,7 +156,7 @@ pub fn determine_type_requirements(
155156
};
156157

157158
if needs_int_cast {
158-
if let Some(sql_type) = type_names.for_target(CastTargetType::Integer) {
159+
if let Some(sql_type) = dialect.type_name_for(CastTargetType::Integer) {
159160
// Don't add duplicate requirements for same column
160161
if !requirements.iter().any(|r| r.column == col_name) {
161162
requirements.push(TypeRequirement {

src/execute/layer.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
//! transformations, stat transforms, and post-query operations.
55
66
use crate::plot::{
7-
AestheticValue, DefaultAestheticValue, Layer, ParameterValue, Scale, Schema, SqlTypeNames,
8-
StatResult,
7+
AestheticValue, DefaultAestheticValue, Layer, ParameterValue, Scale, Schema, StatResult,
98
};
9+
use crate::reader::SqlDialect;
1010
use crate::{naming, DataFrame, GgsqlError, Result};
1111
use polars::prelude::DataType;
1212
use std::collections::{HashMap, HashSet};
@@ -194,14 +194,14 @@ pub fn literal_to_series(name: &str, lit: &ParameterValue, len: usize) -> polars
194194
/// * `layer` - The layer configuration
195195
/// * `full_schema` - The layer's schema (used for column dtype lookup)
196196
/// * `scales` - All resolved scales
197-
/// * `type_names` - SQL type names for the database backend
197+
/// * `dialect` - SQL dialect for the database backend
198198
pub fn apply_pre_stat_transform(
199199
query: &str,
200200
layer: &Layer,
201201
full_schema: &Schema,
202202
aesthetic_schema: &Schema,
203203
scales: &[Scale],
204-
type_names: &SqlTypeNames,
204+
dialect: &dyn SqlDialect,
205205
) -> String {
206206
let mut transform_exprs: Vec<(String, String)> = vec![];
207207
let mut transformed_columns: HashSet<String> = HashSet::new();
@@ -237,7 +237,7 @@ pub fn apply_pre_stat_transform(
237237
// Get pre-stat SQL transformation from scale type (if applicable)
238238
// Each scale type's pre_stat_transform_sql() returns None if not applicable
239239
if let Some(sql) =
240-
scale_type.pre_stat_transform_sql(&aes_col_name, &col_dtype, scale, type_names)
240+
scale_type.pre_stat_transform_sql(&aes_col_name, &col_dtype, scale, dialect)
241241
{
242242
transformed_columns.insert(aes_col_name.clone());
243243
transform_exprs.push((aes_col_name, sql));
@@ -347,7 +347,7 @@ pub fn build_layer_base_query(
347347
/// * `base_query` - The base query from build_layer_base_query
348348
/// * `schema` - The layer's schema (with min/max from base_query)
349349
/// * `scales` - All resolved scales
350-
/// * `type_names` - SQL type names for the database backend
350+
/// * `dialect` - SQL dialect for the database backend
351351
/// * `execute_query` - Function to execute queries (needed for some stat transforms)
352352
///
353353
/// # Returns
@@ -358,7 +358,7 @@ pub fn apply_layer_transforms<F>(
358358
base_query: &str,
359359
schema: &Schema,
360360
scales: &[Scale],
361-
type_names: &SqlTypeNames,
361+
dialect: &dyn SqlDialect,
362362
execute_query: &F,
363363
) -> Result<String>
364364
where
@@ -398,7 +398,7 @@ where
398398
schema,
399399
&aesthetic_schema,
400400
scales,
401-
type_names,
401+
dialect,
402402
);
403403

404404
// Build group_by columns from partition_by
@@ -427,6 +427,7 @@ where
427427
&group_by,
428428
&layer.parameters,
429429
execute_query,
430+
dialect,
430431
)?;
431432

432433
// Apply literal default remappings from geom defaults (e.g., y2 => 0.0 for bar baseline).

src/execute/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ pub struct PreparedData {
876876
/// * `reader` - A Reader implementation for executing SQL
877877
pub fn prepare_data_with_reader<R: Reader>(query: &str, reader: &R) -> Result<PreparedData> {
878878
let execute_query = |sql: &str| reader.execute_sql(sql);
879-
let type_names = reader.sql_type_names();
879+
let dialect = reader.dialect();
880880

881881
// Parse once and create SourceTree
882882
let source_tree = parser::SourceTree::new(query)?;
@@ -1014,7 +1014,7 @@ pub fn prepare_data_with_reader<R: Reader>(query: &str, reader: &R) -> Result<Pr
10141014

10151015
// Determine which columns need type casting
10161016
let type_requirements =
1017-
casting::determine_type_requirements(&specs[0], &layer_type_info, &type_names);
1017+
casting::determine_type_requirements(&specs[0], &layer_type_info, dialect);
10181018

10191019
// Update type info with post-cast dtypes
10201020
// This ensures subsequent schema extraction and scale resolution see the correct types
@@ -1083,7 +1083,7 @@ pub fn prepare_data_with_reader<R: Reader>(query: &str, reader: &R) -> Result<Pr
10831083
&layer_base_queries[idx],
10841084
&layer_schemas[idx],
10851085
&scales,
1086-
&type_names,
1086+
dialect,
10871087
&execute_query,
10881088
)?;
10891089
layer_queries.push(layer_query);

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ pub mod writer;
4848

4949
pub mod execute;
5050

51-
pub mod utils;
5251
pub mod validate;
5352

5453
// Re-export key types for convenience

src/plot/layer/geom/bar.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::types::get_column_name;
77
use super::{DefaultAesthetics, DefaultParam, DefaultParamValue, GeomTrait, GeomType, StatResult};
88
use crate::naming;
99
use crate::plot::types::{DefaultAestheticValue, ParameterValue};
10+
use crate::reader::SqlDialect;
1011
use crate::{DataFrame, GgsqlError, Mappings, Result};
1112

1213
use super::types::Schema;
@@ -81,6 +82,7 @@ impl GeomTrait for Bar {
8182
group_by: &[String],
8283
_parameters: &HashMap<String, ParameterValue>,
8384
_execute_query: &dyn Fn(&str) -> Result<DataFrame>,
85+
_dialect: &dyn SqlDialect,
8486
) -> Result<StatResult> {
8587
stat_bar_count(query, schema, aesthetics, group_by)
8688
}

src/plot/layer/geom/boxplot.rs

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
geom::types::get_column_name, DefaultAestheticValue, DefaultParam, DefaultParamValue,
1010
ParameterValue, StatResult,
1111
},
12-
utils::{sql_greatest, sql_least, sql_percentile},
12+
reader::SqlDialect,
1313
DataFrame, GgsqlError, Mappings, Result,
1414
};
1515

@@ -86,8 +86,9 @@ impl GeomTrait for Boxplot {
8686
group_by: &[String],
8787
parameters: &HashMap<String, ParameterValue>,
8888
_execute_query: &dyn Fn(&str) -> Result<DataFrame>,
89+
dialect: &dyn SqlDialect,
8990
) -> Result<StatResult> {
90-
stat_boxplot(query, aesthetics, group_by, parameters)
91+
stat_boxplot(query, aesthetics, group_by, parameters, dialect)
9192
}
9293
}
9394

@@ -102,6 +103,7 @@ fn stat_boxplot(
102103
aesthetics: &Mappings,
103104
group_by: &[String],
104105
parameters: &HashMap<String, ParameterValue>,
106+
dialect: &dyn SqlDialect,
105107
) -> Result<StatResult> {
106108
let y = get_column_name(aesthetics, "pos2").ok_or_else(|| {
107109
GgsqlError::ValidationError("Boxplot requires 'y' aesthetic mapping".to_string())
@@ -147,7 +149,7 @@ fn stat_boxplot(
147149
}
148150

149151
// Query for boxplot summary statistics
150-
let summary = boxplot_sql_compute_summary(query, &groups, &value_col, coef);
152+
let summary = boxplot_sql_compute_summary(query, &groups, &value_col, coef, dialect);
151153
let stats_query = boxplot_sql_append_outliers(&summary, &groups, &value_col, query, outliers);
152154

153155
Ok(StatResult::Transformed {
@@ -162,13 +164,13 @@ fn stat_boxplot(
162164
})
163165
}
164166

165-
fn boxplot_sql_compute_summary(from: &str, groups: &[String], value: &str, coef: &f64) -> String {
167+
fn boxplot_sql_compute_summary(from: &str, groups: &[String], value: &str, coef: &f64, dialect: &dyn SqlDialect) -> String {
166168
let groups_str = groups.join(", ");
167-
let lower_expr = sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]);
168-
let upper_expr = sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]);
169-
let q1 = sql_percentile(value, 0.25, from, groups);
170-
let median = sql_percentile(value, 0.50, from, groups);
171-
let q3 = sql_percentile(value, 0.75, from, groups);
169+
let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]);
170+
let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]);
171+
let q1 = dialect.sql_percentile(value, 0.25, from, groups);
172+
let median = dialect.sql_percentile(value, 0.50, from, groups);
173+
let q3 = dialect.sql_percentile(value, 0.75, from, groups);
172174
format!(
173175
"SELECT
174176
*,
@@ -293,6 +295,7 @@ fn boxplot_sql_append_outliers(
293295
mod tests {
294296
use super::*;
295297
use crate::plot::AestheticValue;
298+
use crate::reader::AnsiDialect;
296299

297300
// ==================== Helper Functions ====================
298301

@@ -314,7 +317,7 @@ mod tests {
314317
#[test]
315318
fn test_sql_compute_summary_basic() {
316319
let groups = vec!["category".to_string()];
317-
let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5);
320+
let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5, &AnsiDialect);
318321
assert!(result.contains("NTILE(4) OVER (ORDER BY value)"));
319322
assert!(result.contains("AS q1"));
320323
assert!(result.contains("AS median"));
@@ -330,15 +333,15 @@ mod tests {
330333
#[test]
331334
fn test_sql_compute_summary_multiple_groups() {
332335
let groups = vec!["cat".to_string(), "region".to_string()];
333-
let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5);
336+
let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5, &AnsiDialect);
334337
assert!(result.contains("GROUP BY cat, region"));
335338
assert!(result.contains("NTILE(4) OVER (ORDER BY val)"));
336339
}
337340

338341
#[test]
339342
fn test_sql_compute_summary_custom_coef() {
340343
let groups = vec!["pos1".to_string()];
341-
let result = boxplot_sql_compute_summary("q", &groups, "pos2", &2.5);
344+
let result = boxplot_sql_compute_summary("q", &groups, "pos2", &2.5, &AnsiDialect);
342345
assert!(result.contains("2.5"));
343346
assert!(
344347
result.contains("(CASE WHEN (q1 - 2.5 * (q3 - q1)) >= (min) THEN (q1 - 2.5 * (q3 - q1)) ELSE (min) END)")
@@ -364,11 +367,11 @@ mod tests {
364367
#[test]
365368
fn test_boxplot_sql_compute_summary_single_group() {
366369
let groups = vec!["category".to_string()];
367-
let result = boxplot_sql_compute_summary("SELECT * FROM sales", &groups, "price", &1.5);
370+
let result = boxplot_sql_compute_summary("SELECT * FROM sales", &groups, "price", &1.5, &AnsiDialect);
368371

369-
let q1 = sql_percentile("price", 0.25, "SELECT * FROM sales", &groups);
370-
let median = sql_percentile("price", 0.50, "SELECT * FROM sales", &groups);
371-
let q3 = sql_percentile("price", 0.75, "SELECT * FROM sales", &groups);
372+
let q1 = AnsiDialect.sql_percentile("price", 0.25, "SELECT * FROM sales", &groups);
373+
let median = AnsiDialect.sql_percentile("price", 0.50, "SELECT * FROM sales", &groups);
374+
let q3 = AnsiDialect.sql_percentile("price", 0.75, "SELECT * FROM sales", &groups);
372375
let expected = format!(
373376
r#"SELECT
374377
*,
@@ -394,11 +397,11 @@ mod tests {
394397
#[test]
395398
fn test_boxplot_sql_compute_summary_multiple_groups() {
396399
let groups = vec!["region".to_string(), "product".to_string()];
397-
let result = boxplot_sql_compute_summary("SELECT * FROM data", &groups, "revenue", &1.5);
400+
let result = boxplot_sql_compute_summary("SELECT * FROM data", &groups, "revenue", &1.5, &AnsiDialect);
398401

399-
let q1 = sql_percentile("revenue", 0.25, "SELECT * FROM data", &groups);
400-
let median = sql_percentile("revenue", 0.50, "SELECT * FROM data", &groups);
401-
let q3 = sql_percentile("revenue", 0.75, "SELECT * FROM data", &groups);
402+
let q1 = AnsiDialect.sql_percentile("revenue", 0.25, "SELECT * FROM data", &groups);
403+
let median = AnsiDialect.sql_percentile("revenue", 0.50, "SELECT * FROM data", &groups);
404+
let q3 = AnsiDialect.sql_percentile("revenue", 0.75, "SELECT * FROM data", &groups);
402405
let expected = format!(
403406
r#"SELECT
404407
*,
@@ -507,7 +510,7 @@ mod tests {
507510
);
508511
parameters.insert("outliers".to_string(), ParameterValue::Boolean(true));
509512

510-
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters);
513+
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters, &AnsiDialect);
511514

512515
assert!(result.is_err());
513516
assert!(result.unwrap_err().to_string().contains("coef"));
@@ -522,7 +525,7 @@ mod tests {
522525
parameters.insert("outliers".to_string(), ParameterValue::Boolean(true));
523526
// Missing coef
524527

525-
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters);
528+
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters, &AnsiDialect);
526529

527530
assert!(result.is_err());
528531
assert!(result.unwrap_err().to_string().contains("coef"));
@@ -540,7 +543,7 @@ mod tests {
540543
ParameterValue::String("yes".to_string()),
541544
);
542545

543-
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters);
546+
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters, &AnsiDialect);
544547

545548
assert!(result.is_err());
546549
assert!(result.unwrap_err().to_string().contains("outliers"));
@@ -555,7 +558,7 @@ mod tests {
555558
parameters.insert("coef".to_string(), ParameterValue::Number(1.5));
556559
// Missing outliers
557560

558-
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters);
561+
let result = stat_boxplot("SELECT * FROM data", &aesthetics, &groups, &parameters, &AnsiDialect);
559562

560563
assert!(result.is_err());
561564
assert!(result.unwrap_err().to_string().contains("outliers"));

0 commit comments

Comments
 (0)