Skip to content

Commit 5d68d92

Browse files
committed
WIP: Quoting is more fun
1 parent ae29fa0 commit 5d68d92

8 files changed

Lines changed: 123 additions & 99 deletions

File tree

src/plot/layer/geom/boxplot.rs

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,16 @@ fn boxplot_sql_compute_summary(
171171
coef: &f64,
172172
dialect: &dyn SqlDialect,
173173
) -> String {
174-
let groups_str = groups.join(", ");
174+
let quoted_groups: Vec<String> = groups.iter().map(|g| format!("\"{}\"", g)).collect();
175+
let groups_str = quoted_groups.join(", ");
175176
let lower_expr = dialect.sql_greatest(&[&format!("q1 - {coef} * (q3 - q1)"), "min"]);
176177
let upper_expr = dialect.sql_least(&[&format!("q3 + {coef} * (q3 - q1)"), "max"]);
177178
let q1 = dialect.sql_percentile(value, 0.25, from, groups);
178179
let median = dialect.sql_percentile(value, 0.50, from, groups);
179180
let q3 = dialect.sql_percentile(value, 0.75, from, groups);
180181
let qt = "\"__ggsql_qt__\"";
181182
let fn_alias = "\"__ggsql_fn__\"";
183+
let quoted_value = format!("\"{}\"", value);
182184
format!(
183185
"SELECT
184186
*,
@@ -199,7 +201,7 @@ fn boxplot_sql_compute_summary(
199201
lower_expr = lower_expr,
200202
upper_expr = upper_expr,
201203
groups = groups_str,
202-
value = value,
204+
value = quoted_value,
203205
from = from,
204206
q1 = q1,
205207
median = median,
@@ -211,10 +213,12 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St
211213
let mut join_pairs = Vec::new();
212214
let mut keep_columns = Vec::new();
213215
for column in groups {
214-
join_pairs.push(format!("raw.{} = summary.{}", column, column));
215-
keep_columns.push(format!("raw.{}", column));
216+
let quoted = format!("\"{}\"", column);
217+
join_pairs.push(format!("raw.{} = summary.{}", quoted, quoted));
218+
keep_columns.push(format!("raw.{}", quoted));
216219
}
217220

221+
let quoted_value = format!("\"{}\"", value);
218222
// We're joining outliers with the summary to use the lower/upper whisker
219223
// values as a filter
220224
format!(
@@ -225,7 +229,7 @@ fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> St
225229
FROM ({from}) raw
226230
JOIN summary ON {pairs}
227231
WHERE raw.{value} NOT BETWEEN summary.lower AND summary.upper",
228-
value = value,
232+
value = quoted_value,
229233
groups = keep_columns.join(", "),
230234
pairs = join_pairs.join(" AND "),
231235
from = from
@@ -243,7 +247,8 @@ fn boxplot_sql_append_outliers(
243247
let value2_name = naming::stat_column("value2");
244248
let type_name = naming::stat_column("type");
245249

246-
let groups_str = groups.join(", ");
250+
let quoted_groups: Vec<String> = groups.iter().map(|g| format!("\"{}\"", g)).collect();
251+
let groups_str = quoted_groups.join(", ");
247252

248253
// Helper to build visual-element rows from summary table
249254
// Each row type maps to one visual element with y and yend where needed
@@ -326,14 +331,14 @@ mod tests {
326331
fn test_sql_compute_summary_basic() {
327332
let groups = vec!["category".to_string()];
328333
let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5, &AnsiDialect);
329-
assert!(result.contains("NTILE(4) OVER (ORDER BY value)"));
334+
assert!(result.contains("NTILE(4) OVER (ORDER BY \"value\")"));
330335
assert!(result.contains("AS q1"));
331336
assert!(result.contains("AS median"));
332337
assert!(result.contains("AS q3"));
333-
assert!(result.contains("MIN(value) AS min"));
334-
assert!(result.contains("MAX(value) AS max"));
335-
assert!(result.contains("WHERE value IS NOT NULL"));
336-
assert!(result.contains("GROUP BY category"));
338+
assert!(result.contains("MIN(\"value\") AS min"));
339+
assert!(result.contains("MAX(\"value\") AS max"));
340+
assert!(result.contains("WHERE \"value\" IS NOT NULL"));
341+
assert!(result.contains("GROUP BY \"category\""));
337342
assert!(result.contains("CASE WHEN (q1 - 1.5"));
338343
assert!(result.contains("CASE WHEN (q3 + 1.5"));
339344
}
@@ -342,8 +347,8 @@ mod tests {
342347
fn test_sql_compute_summary_multiple_groups() {
343348
let groups = vec!["cat".to_string(), "region".to_string()];
344349
let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5, &AnsiDialect);
345-
assert!(result.contains("GROUP BY cat, region"));
346-
assert!(result.contains("NTILE(4) OVER (ORDER BY val)"));
350+
assert!(result.contains("GROUP BY \"cat\", \"region\""));
351+
assert!(result.contains("NTILE(4) OVER (ORDER BY \"val\")"));
347352
}
348353

349354
#[test]
@@ -364,8 +369,8 @@ mod tests {
364369
let groups = vec!["cat".to_string(), "region".to_string()];
365370
let result = boxplot_sql_filter_outliers(&groups, "value", "raw_data");
366371
assert!(result.contains("JOIN summary ON"));
367-
assert!(result.contains("raw.cat = summary.cat"));
368-
assert!(result.contains("raw.region = summary.region"));
372+
assert!(result.contains("raw.\"cat\" = summary.\"cat\""));
373+
assert!(result.contains("raw.\"region\" = summary.\"region\""));
369374
assert!(result.contains("NOT BETWEEN summary.lower AND summary.upper"));
370375
assert!(result.contains("'outlier' AS type"));
371376
}
@@ -393,15 +398,15 @@ mod tests {
393398
(CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper
394399
FROM (
395400
SELECT
396-
category,
397-
MIN(price) AS min,
398-
MAX(price) AS max,
401+
"category",
402+
MIN("price") AS min,
403+
MAX("price") AS max,
399404
{q1} AS q1,
400405
{median} AS median,
401406
{q3} AS q3
402407
FROM (SELECT * FROM sales) AS "__ggsql_qt__"
403-
WHERE price IS NOT NULL
404-
GROUP BY category
408+
WHERE "price" IS NOT NULL
409+
GROUP BY "category"
405410
) AS "__ggsql_fn__""#
406411
);
407412

@@ -429,15 +434,15 @@ mod tests {
429434
(CASE WHEN (q3 + 1.5 * (q3 - q1)) <= (max) THEN (q3 + 1.5 * (q3 - q1)) ELSE (max) END) AS upper
430435
FROM (
431436
SELECT
432-
region, product,
433-
MIN(revenue) AS min,
434-
MAX(revenue) AS max,
437+
"region", "product",
438+
MIN("revenue") AS min,
439+
MAX("revenue") AS max,
435440
{q1} AS q1,
436441
{median} AS median,
437442
{q3} AS q3
438443
FROM (SELECT * FROM data) AS "__ggsql_qt__"
439-
WHERE revenue IS NOT NULL
440-
GROUP BY region, product
444+
WHERE "revenue" IS NOT NULL
445+
GROUP BY "region", "product"
441446
) AS "__ggsql_fn__""#
442447
);
443448

@@ -501,8 +506,8 @@ mod tests {
501506
let raw = "(SELECT * FROM raw_data)";
502507
let result = boxplot_sql_append_outliers(summary, &groups, "val", raw, &true);
503508

504-
// Verify all groups are present
505-
assert!(result.contains("cat, region, year"));
509+
// Verify all groups are present (quoted)
510+
assert!(result.contains("\"cat\", \"region\", \"year\""));
506511

507512
// Check structure
508513
assert!(result.contains("WITH"));
@@ -511,9 +516,9 @@ mod tests {
511516

512517
// Verify outlier join conditions for all groups
513518
let outlier_section = result.split("outliers AS").nth(1).unwrap();
514-
assert!(outlier_section.contains("raw.cat = summary.cat"));
515-
assert!(outlier_section.contains("raw.region = summary.region"));
516-
assert!(outlier_section.contains("raw.year = summary.year"));
519+
assert!(outlier_section.contains("raw.\"cat\" = summary.\"cat\""));
520+
assert!(outlier_section.contains("raw.\"region\" = summary.\"region\""));
521+
assert!(outlier_section.contains("raw.\"year\" = summary.\"year\""));
517522
}
518523

519524
// ==================== Parameter Validation Tests ====================

src/plot/layer/geom/density.rs

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,14 @@ fn compute_range_sql(
182182
from: &str,
183183
execute: &dyn Fn(&str) -> crate::Result<polars::prelude::DataFrame>,
184184
) -> Result<(f64, f64)> {
185+
let quoted_value = format!("\"{}\"", value);
185186
let query = format!(
186187
"SELECT
187188
MIN({value}) AS min,
188189
MAX({value}) AS max
189190
FROM ({from})
190191
WHERE {value} IS NOT NULL",
191-
value = value,
192+
value = quoted_value,
192193
from = from
193194
);
194195
let result = execute(&query)?;
@@ -234,7 +235,8 @@ fn density_sql_bandwidth(
234235
) -> String {
235236
let mut group_by = String::new();
236237
let mut comma = String::new();
237-
let groups_str = groups.join(", ");
238+
let quoted_groups: Vec<String> = groups.iter().map(|g| format!("\"{}\"", g)).collect();
239+
let groups_str = quoted_groups.join(", ");
238240

239241
if !groups_str.is_empty() {
240242
group_by = format!("GROUP BY {}", groups_str);
@@ -266,6 +268,7 @@ fn density_sql_bandwidth(
266268
};
267269
return cte;
268270
}
271+
let quoted_value = format!("\"{}\"", value);
269272
format!(
270273
"WITH RECURSIVE
271274
bandwidth AS (
@@ -277,7 +280,7 @@ fn density_sql_bandwidth(
277280
{group_by}
278281
)",
279282
rule = silverman_rule(adjust, value, from, groups, dialect),
280-
value = value,
283+
value = quoted_value,
281284
group_by = group_by,
282285
groups_str = groups_str,
283286
comma = comma,
@@ -296,7 +299,8 @@ fn silverman_rule(
296299
// The query computes Silverman's rule of thumb (R's `stats::bw.nrd0()`).
297300
// We absorb the adjustment in the 0.9 multiplier of the rule
298301
let adjust = 0.9 * adjust;
299-
let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = value_column);
302+
let v = format!("\"{}\"", value_column);
303+
let stddev = format!("SQRT(AVG({v}*{v}) - AVG({v})*AVG({v}))", v = v);
300304
let q75 = dialect.sql_percentile(value_column, 0.75, from, groups);
301305
let q25 = dialect.sql_percentile(value_column, 0.25, from, groups);
302306
let iqr = format!("({q75} - {q25}) / 1.34");
@@ -364,22 +368,24 @@ fn choose_kde_kernel(parameters: &HashMap<String, ParameterValue>) -> Result<Str
364368
fn build_data_cte(value: &str, weight: Option<&str>, from: &str, group_by: &[String]) -> String {
365369
// Include weight column if provided, otherwise default to 1.0
366370
let weight_col = if let Some(w) = weight {
367-
format!(", {} AS weight", w)
371+
format!(", \"{}\" AS weight", w)
368372
} else {
369373
", 1.0 AS weight".to_string()
370374
};
371375

376+
let quoted_value = format!("\"{}\"", value);
372377
// Only filter out nulls in value column, keep NULLs in group columns
373-
let filter_valid = format!("{} IS NOT NULL", value);
378+
let filter_valid = format!("{} IS NOT NULL", quoted_value);
374379

380+
let quoted_groups: Vec<String> = group_by.iter().map(|g| format!("\"{}\"", g)).collect();
375381
format!(
376382
"data AS (
377383
SELECT {groups}{value} AS val{weight_col}
378384
FROM ({from})
379385
WHERE {filter_valid}
380386
)",
381-
groups = with_trailing_comma(&group_by.join(", ")),
382-
value = value,
387+
groups = with_trailing_comma(&quoted_groups.join(", ")),
388+
value = quoted_value,
383389
weight_col = weight_col,
384390
from = from,
385391
filter_valid = filter_valid
@@ -419,7 +425,8 @@ fn build_grid_cte(
419425
);
420426
}
421427

422-
let groups = groups.join(", ");
428+
let quoted_groups: Vec<String> = groups.iter().map(|g| format!("\"{}\"", g)).collect();
429+
let groups = quoted_groups.join(", ");
423430
format!(
424431
"{seq}, grid AS (
425432
SELECT
@@ -451,7 +458,7 @@ fn compute_density(
451458
} else {
452459
group_by
453460
.iter()
454-
.map(|g| format!("data.{col} IS NOT DISTINCT FROM bandwidth.{col}", col = g))
461+
.map(|g| format!("data.\"{}\" IS NOT DISTINCT FROM bandwidth.\"{}\"", g, g))
455462
.collect::<Vec<String>>()
456463
.join(" AND ")
457464
};
@@ -462,7 +469,7 @@ fn compute_density(
462469
} else {
463470
let grid_data_conds: Vec<String> = group_by
464471
.iter()
465-
.map(|g| format!("grid.{col} IS NOT DISTINCT FROM data.{col}", col = g))
472+
.map(|g| format!("grid.\"{}\" IS NOT DISTINCT FROM data.\"{}\"", g, g))
466473
.collect();
467474
format!("WHERE {}", grid_data_conds.join(" AND "))
468475
};
@@ -476,7 +483,7 @@ fn compute_density(
476483
);
477484

478485
// Build group-related SQL fragments
479-
let grid_groups: Vec<String> = group_by.iter().map(|g| format!("grid.{}", g)).collect();
486+
let grid_groups: Vec<String> = group_by.iter().map(|g| format!("grid.\"{}\"", g)).collect();
480487
let aggregation = format!(
481488
"GROUP BY grid.x{grid_group_by}
482489
ORDER BY grid.x{grid_group_by}",
@@ -486,7 +493,8 @@ fn compute_density(
486493
let groups = if group_by.is_empty() {
487494
String::new()
488495
} else {
489-
format!("{},", group_by.join(", "))
496+
let quoted: Vec<String> = group_by.iter().map(|g| format!("\"{}\"", g)).collect();
497+
format!("{},", quoted.join(", "))
490498
};
491499

492500
// Generate the density computation query
@@ -546,9 +554,9 @@ mod tests {
546554

547555
let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw),
548556
data AS (
549-
SELECT x AS val, 1.0 AS weight
557+
SELECT "x" AS val, 1.0 AS weight
550558
FROM (SELECT x FROM (VALUES (1.0), (2.0), (3.0)) AS t(x))
551-
WHERE x IS NOT NULL
559+
WHERE "x" IS NOT NULL
552560
),
553561
"__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512),
554562
grid AS (
@@ -607,37 +615,37 @@ mod tests {
607615
let kernel = choose_kde_kernel(&parameters).expect("kernel should be valid");
608616
let sql = compute_density("x", &groups, kernel, &bw_cte, &data_cte, &grid_cte);
609617

610-
let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY region, category),
618+
let expected = r#"WITH RECURSIVE bandwidth AS (SELECT 0.5 AS bw, "region", "category" FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category)) GROUP BY "region", "category"),
611619
data AS (
612-
SELECT region, category, x AS val, 1.0 AS weight
620+
SELECT "region", "category", "x" AS val, 1.0 AS weight
613621
FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))
614-
WHERE x IS NOT NULL
622+
WHERE "x" IS NOT NULL
615623
),
616624
"__ggsql_base__"(n) AS (SELECT 0 UNION ALL SELECT n + 1 FROM "__ggsql_base__" WHERE n < 7),"__ggsql_seq__"(n) AS (SELECT CAST(a.n * 64 + b.n * 8 + c.n AS REAL) AS n FROM "__ggsql_base__" a, "__ggsql_base__" b, "__ggsql_base__" c WHERE a.n * 64 + b.n * 8 + c.n < 512),
617625
grid AS (
618626
SELECT
619-
region, category,
627+
"region", "category",
620628
-11 + ("__ggsql_seq__".n * 22 / 511) AS x
621629
FROM "__ggsql_seq__"
622-
CROSS JOIN (SELECT DISTINCT region, category FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups
630+
CROSS JOIN (SELECT DISTINCT "region", "category" FROM (SELECT x, region, category FROM (VALUES (1.0, 'A', 'X'), (2.0, 'B', 'Y')) AS t(x, region, category))) AS groups
623631
)
624632
SELECT
625633
"__ggsql_stat_x",
626-
region, category,
634+
"region", "category",
627635
"__ggsql_stat_intensity",
628636
"__ggsql_stat_intensity" / "__norm" AS "__ggsql_stat_density"
629637
FROM (
630638
SELECT
631639
grid.x AS "__ggsql_stat_x",
632-
grid.region, grid.category,
640+
grid."region", grid."category",
633641
SUM(data.weight * ((EXP(-0.5 * (grid.x - data.val) * (grid.x - data.val) / (bandwidth.bw * bandwidth.bw))) * 0.3989422804014327)) / MIN(bandwidth.bw) AS "__ggsql_stat_intensity",
634642
SUM(data.weight) AS "__norm"
635643
FROM data
636-
INNER JOIN bandwidth ON data.region IS NOT DISTINCT FROM bandwidth.region AND data.category IS NOT DISTINCT FROM bandwidth.category
644+
INNER JOIN bandwidth ON data."region" IS NOT DISTINCT FROM bandwidth."region" AND data."category" IS NOT DISTINCT FROM bandwidth."category"
637645
CROSS JOIN grid
638-
WHERE grid.region IS NOT DISTINCT FROM data.region AND grid.category IS NOT DISTINCT FROM data.category
639-
GROUP BY grid.x, grid.region, grid.category
640-
ORDER BY grid.x, grid.region, grid.category
646+
WHERE grid."region" IS NOT DISTINCT FROM data."region" AND grid."category" IS NOT DISTINCT FROM data."category"
647+
GROUP BY grid.x, grid."region", grid."category"
648+
ORDER BY grid.x, grid."region", grid."category"
641649
)"#;
642650

643651
// Normalize whitespace for comparison
@@ -718,7 +726,7 @@ mod tests {
718726

719727
// Verify SQL uses NTILE-based percentile subqueries with grouping
720728
assert!(bw_cte.contains("NTILE(4)"));
721-
assert!(bw_cte.contains("GROUP BY region"));
729+
assert!(bw_cte.contains("GROUP BY \"region\""));
722730
let expected_rule = silverman_rule(1.0, "x", query, &groups, &AnsiDialect);
723731
assert!(normalize(&bw_cte).contains(&normalize(&expected_rule)));
724732

0 commit comments

Comments
 (0)