Skip to content

Commit 5669004

Browse files
cpsievertclaude
andcommitted
fix: compute position stacking per-facet-panel instead of globally (#244)
Stacking `.over()` partitions now include facet columns, so cumulative sums reset within each facet panel. Previously, bars in later panels stacked on top of cumulative values from earlier panels. Also applies to Fill and Center stacking modes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3bbed9d commit 5669004

2 files changed

Lines changed: 94 additions & 4 deletions

File tree

src/execute/position.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,4 +322,80 @@ mod tests {
322322
assert!((-0.3..=0.3).contains(&v));
323323
}
324324
}
325+
326+
#[test]
327+
fn test_stack_resets_per_facet_panel() {
328+
// Stacking should compute independently within each facet panel.
329+
// Without this, bars in the second facet panel stack on top of
330+
// cumulative values from the first panel (see issue #244).
331+
//
332+
// Two facet panels (F1, F2) each with the same x="A" and two
333+
// fill groups (X, Y). Stacking within each panel should start from 0.
334+
let df = df! {
335+
"__ggsql_aes_pos1__" => ["A", "A", "A", "A"],
336+
"__ggsql_aes_pos2__" => [10.0, 20.0, 30.0, 40.0],
337+
"__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0],
338+
"__ggsql_aes_fill__" => ["X", "Y", "X", "Y"],
339+
"__ggsql_aes_facet1__" => ["F1", "F1", "F2", "F2"],
340+
}
341+
.unwrap();
342+
343+
let mut layer = crate::plot::Layer::new(Geom::bar());
344+
layer.mappings = {
345+
let mut m = Mappings::new();
346+
m.insert("pos1", AestheticValue::standard_column("__ggsql_aes_pos1__"));
347+
m.insert("pos2", AestheticValue::standard_column("__ggsql_aes_pos2__"));
348+
m.insert("pos2end", AestheticValue::standard_column("__ggsql_aes_pos2end__"));
349+
m.insert("fill", AestheticValue::standard_column("__ggsql_aes_fill__"));
350+
m.insert("facet1", AestheticValue::standard_column("__ggsql_aes_facet1__"));
351+
m
352+
};
353+
layer.partition_by = vec![
354+
"__ggsql_aes_fill__".to_string(),
355+
"__ggsql_aes_facet1__".to_string(),
356+
];
357+
layer.position = Position::stack();
358+
layer.data_key = Some("__ggsql_layer_0__".to_string());
359+
360+
let mut spec = Plot::new();
361+
spec.scales.push(make_discrete_scale("pos1"));
362+
spec.scales.push(make_continuous_scale("pos2"));
363+
let mut data_map = HashMap::new();
364+
data_map.insert("__ggsql_layer_0__".to_string(), df);
365+
366+
let mut spec_with_layer = spec;
367+
spec_with_layer.layers.push(layer);
368+
369+
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
370+
371+
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
372+
373+
// Sort by facet then fill so we can assert in predictable order
374+
let result_df = result_df
375+
.clone()
376+
.lazy()
377+
.sort_by_exprs(
378+
[col("__ggsql_aes_facet1__"), col("__ggsql_aes_fill__")],
379+
SortMultipleOptions::default(),
380+
)
381+
.collect()
382+
.unwrap();
383+
384+
let pos2 = result_df.column("__ggsql_aes_pos2__").unwrap().f64().unwrap();
385+
let pos2end = result_df.column("__ggsql_aes_pos2end__").unwrap().f64().unwrap();
386+
387+
let pos2_vals: Vec<f64> = pos2.into_iter().flatten().collect();
388+
let pos2end_vals: Vec<f64> = pos2end.into_iter().flatten().collect();
389+
390+
// Expected (sorted by facet, fill):
391+
// F1/X: pos2end=0, pos2=10 (first in panel, starts at 0)
392+
// F1/Y: pos2end=10, pos2=30 (stacks on X)
393+
// F2/X: pos2end=0, pos2=30 (first in panel, should reset to 0)
394+
// F2/Y: pos2end=30, pos2=70 (stacks on X)
395+
assert_eq!(
396+
pos2end_vals[2], 0.0,
397+
"F2 panel first bar should start at 0, not carry over from F1. pos2end={:?}, pos2={:?}",
398+
pos2end_vals, pos2_vals
399+
);
400+
}
325401
}

src/plot/layer/position/stack.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,22 +263,36 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re
263263
// 2. stack_end_col = lag(stack_col, 1, 0) - the bar bottom/start (previous stack top)
264264
// The cumsum naturally stacks across the grouping column values
265265

266+
// Build the partition columns for .over(): group column + facet columns.
267+
// Facet columns must be included so stacking resets per facet panel,
268+
// matching ggplot2 where position adjustments are computed per-panel.
269+
let mut over_cols: Vec<Expr> = vec![col(&group_col)];
270+
for partition_col in &layer.partition_by {
271+
if naming::is_aesthetic_column(partition_col) {
272+
if let Some(aes) = naming::extract_aesthetic_name(partition_col) {
273+
if aes.starts_with("facet") {
274+
over_cols.push(col(partition_col));
275+
}
276+
}
277+
}
278+
}
279+
266280
// Treat NA heights as 0 for stacking
267281
// Compute cumulative sums (shared by all modes)
268282
let lf = lf
269283
.with_column(col(&stack_col).fill_null(lit(0.0)).alias(&stack_col))
270284
.with_column(
271285
col(&stack_col)
272286
.cum_sum(false)
273-
.over([col(&group_col)])
287+
.over(&over_cols)
274288
.alias("__cumsum__"),
275289
)
276290
.with_column(
277291
col(&stack_col)
278292
.cum_sum(false)
279293
.shift(lit(1))
280294
.fill_null(lit(0.0))
281-
.over([col(&group_col)])
295+
.over(&over_cols)
282296
.alias("__cumsum_lag__"),
283297
);
284298

@@ -290,15 +304,15 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re
290304
vec!["__cumsum__", "__cumsum_lag__"],
291305
),
292306
StackMode::Fill(target) => {
293-
let total = col(&stack_col).sum().over([col(&group_col)]);
307+
let total = col(&stack_col).sum().over(&over_cols);
294308
(
295309
(col("__cumsum__") / total.clone() * lit(target)).alias(&stack_col),
296310
(col("__cumsum_lag__") / total * lit(target)).alias(&stack_end_col),
297311
vec!["__cumsum__", "__cumsum_lag__"],
298312
)
299313
}
300314
StackMode::Center => {
301-
let half_total = col(&stack_col).sum().over([col(&group_col)]) / lit(2.0);
315+
let half_total = col(&stack_col).sum().over(&over_cols) / lit(2.0);
302316
(
303317
(col("__cumsum__") - half_total.clone()).alias(&stack_col),
304318
(col("__cumsum_lag__") - half_total).alias(&stack_end_col),

0 commit comments

Comments
 (0)