Skip to content

Commit 00a0947

Browse files
authored
fix: compute position stacking per-facet-panel instead of globally (#245)
1 parent 5bec040 commit 00a0947

5 files changed

Lines changed: 234 additions & 11 deletions

File tree

src/execute/position.rs

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub fn apply_position_adjustments(
5858
#[cfg(test)]
5959
mod tests {
6060
use super::*;
61+
use crate::plot::facet::{Facet, FacetLayout};
6162
use crate::plot::layer::{Geom, Position};
6263
use crate::plot::{AestheticValue, Mappings, ParameterValue, Scale, ScaleType};
6364
use polars::prelude::*;
@@ -322,4 +323,182 @@ mod tests {
322323
assert!((-0.3..=0.3).contains(&v));
323324
}
324325
}
326+
327+
#[test]
328+
fn test_stack_resets_per_facet_panel() {
329+
// Stacking should compute independently within each facet panel.
330+
// Without this, bars in the second facet panel stack on top of
331+
// cumulative values from the first panel (see issue #244).
332+
//
333+
// Two facet panels (F1, F2) each with the same x="A" and two
334+
// fill groups (X, Y). Stacking within each panel should start from 0.
335+
let df = df! {
336+
"__ggsql_aes_pos1__" => ["A", "A", "A", "A"],
337+
"__ggsql_aes_pos2__" => [10.0, 20.0, 30.0, 40.0],
338+
"__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0],
339+
"__ggsql_aes_fill__" => ["X", "Y", "X", "Y"],
340+
"__ggsql_aes_facet1__" => ["F1", "F1", "F2", "F2"],
341+
}
342+
.unwrap();
343+
344+
let mut layer = crate::plot::Layer::new(Geom::bar());
345+
layer.mappings = {
346+
let mut m = Mappings::new();
347+
m.insert(
348+
"pos1",
349+
AestheticValue::standard_column("__ggsql_aes_pos1__"),
350+
);
351+
m.insert(
352+
"pos2",
353+
AestheticValue::standard_column("__ggsql_aes_pos2__"),
354+
);
355+
m.insert(
356+
"pos2end",
357+
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
358+
);
359+
m.insert(
360+
"fill",
361+
AestheticValue::standard_column("__ggsql_aes_fill__"),
362+
);
363+
m.insert(
364+
"facet1",
365+
AestheticValue::standard_column("__ggsql_aes_facet1__"),
366+
);
367+
m
368+
};
369+
layer.partition_by = vec![
370+
"__ggsql_aes_fill__".to_string(),
371+
"__ggsql_aes_facet1__".to_string(),
372+
];
373+
layer.position = Position::stack();
374+
layer.data_key = Some("__ggsql_layer_0__".to_string());
375+
376+
let mut spec = Plot::new();
377+
spec.scales.push(make_discrete_scale("pos1"));
378+
spec.scales.push(make_continuous_scale("pos2"));
379+
spec.facet = Some(Facet::new(FacetLayout::Wrap {
380+
variables: vec!["facet_var".to_string()],
381+
}));
382+
let mut data_map = HashMap::new();
383+
data_map.insert("__ggsql_layer_0__".to_string(), df);
384+
385+
let mut spec_with_layer = spec;
386+
spec_with_layer.layers.push(layer);
387+
388+
apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap();
389+
390+
let result_df = data_map.get("__ggsql_layer_0__").unwrap();
391+
392+
// Sort by facet then fill so we can assert in predictable order
393+
let result_df = result_df
394+
.clone()
395+
.lazy()
396+
.sort_by_exprs(
397+
[col("__ggsql_aes_facet1__"), col("__ggsql_aes_fill__")],
398+
SortMultipleOptions::default(),
399+
)
400+
.collect()
401+
.unwrap();
402+
403+
let pos2 = result_df
404+
.column("__ggsql_aes_pos2__")
405+
.unwrap()
406+
.f64()
407+
.unwrap();
408+
let pos2end = result_df
409+
.column("__ggsql_aes_pos2end__")
410+
.unwrap()
411+
.f64()
412+
.unwrap();
413+
414+
let pos2_vals: Vec<f64> = pos2.into_iter().flatten().collect();
415+
let pos2end_vals: Vec<f64> = pos2end.into_iter().flatten().collect();
416+
417+
// Expected (sorted by facet, fill):
418+
// F1/X: pos2end=0, pos2=10 (first in panel, starts at 0)
419+
// F1/Y: pos2end=10, pos2=30 (stacks on X)
420+
// F2/X: pos2end=0, pos2=30 (first in panel, should reset to 0)
421+
// F2/Y: pos2end=30, pos2=70 (stacks on X)
422+
assert_eq!(
423+
pos2end_vals[2], 0.0,
424+
"F2 panel first bar should start at 0, not carry over from F1. pos2end={:?}, pos2={:?}",
425+
pos2end_vals, pos2_vals
426+
);
427+
}
428+
429+
#[test]
430+
fn test_dodge_ignores_facet_columns_in_group_count() {
431+
// Dodge should compute n_groups per facet panel, not globally.
432+
// With fill=["X","Y"] and facet=["F1","F2"], dodge should see
433+
// 2 groups (X, Y) not 4 (X-F1, X-F2, Y-F1, Y-F2).
434+
//
435+
// With 2 groups and default width 0.9:
436+
// adjusted_width = 0.9 / 2 = 0.45
437+
// offsets: -0.225 (group X), +0.225 (group Y)
438+
//
439+
// If facet columns incorrectly inflate n_groups to 4:
440+
// adjusted_width = 0.9 / 4 = 0.225
441+
// offsets would be different (spread across 4 positions)
442+
let df = df! {
443+
"__ggsql_aes_pos1__" => ["A", "A", "A", "A"],
444+
"__ggsql_aes_pos2__" => [10.0, 20.0, 30.0, 40.0],
445+
"__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0],
446+
"__ggsql_aes_fill__" => ["X", "Y", "X", "Y"],
447+
"__ggsql_aes_facet1__" => ["F1", "F1", "F2", "F2"],
448+
}
449+
.unwrap();
450+
451+
let mut layer = crate::plot::Layer::new(Geom::bar());
452+
layer.mappings = {
453+
let mut m = Mappings::new();
454+
m.insert(
455+
"pos1",
456+
AestheticValue::standard_column("__ggsql_aes_pos1__"),
457+
);
458+
m.insert(
459+
"pos2",
460+
AestheticValue::standard_column("__ggsql_aes_pos2__"),
461+
);
462+
m.insert(
463+
"pos2end",
464+
AestheticValue::standard_column("__ggsql_aes_pos2end__"),
465+
);
466+
m.insert(
467+
"fill",
468+
AestheticValue::standard_column("__ggsql_aes_fill__"),
469+
);
470+
m.insert(
471+
"facet1",
472+
AestheticValue::standard_column("__ggsql_aes_facet1__"),
473+
);
474+
m
475+
};
476+
layer.partition_by = vec![
477+
"__ggsql_aes_fill__".to_string(),
478+
"__ggsql_aes_facet1__".to_string(),
479+
];
480+
layer.position = Position::dodge();
481+
layer.data_key = Some("__ggsql_layer_0__".to_string());
482+
483+
let mut spec = Plot::new();
484+
spec.scales.push(make_discrete_scale("pos1"));
485+
spec.scales.push(make_continuous_scale("pos2"));
486+
spec.facet = Some(Facet::new(FacetLayout::Wrap {
487+
variables: vec!["facet_var".to_string()],
488+
}));
489+
let mut data_map = HashMap::new();
490+
data_map.insert("__ggsql_layer_0__".to_string(), df);
491+
492+
spec.layers.push(layer);
493+
494+
apply_position_adjustments(&mut spec, &mut data_map).unwrap();
495+
496+
// With 2 groups (X, Y), adjusted_width should be 0.45
497+
let adjusted = spec.layers[0].adjusted_width.unwrap();
498+
assert!(
499+
(adjusted - 0.45).abs() < 0.001,
500+
"adjusted_width should be 0.45 (2 groups), got {} (facet columns inflated group count)",
501+
adjusted
502+
);
503+
}
325504
}

src/plot/layer/position/dodge.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
//! - If only pos2 is discrete → dodge vertically (pos2offset)
77
//! - If both are discrete → 2D grid dodge (both offsets, arranged in a grid)
88
9-
use super::{compute_dodge_offsets, is_continuous_scale, Layer, PositionTrait, PositionType};
9+
use super::{
10+
compute_dodge_offsets, is_continuous_scale, non_facet_partition_cols, Layer, PositionTrait,
11+
PositionType,
12+
};
1013
use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue};
1114
use crate::{naming, DataFrame, GgsqlError, Plot, Result};
1215
use polars::prelude::*;
@@ -159,8 +162,10 @@ fn apply_dodge_with_width(
159162
return Ok((df, None));
160163
}
161164

162-
// Compute group indices
163-
let group_info = match compute_group_indices(&df, &layer.partition_by)? {
165+
// Compute group indices, excluding facet columns so group count
166+
// reflects within-panel groups (not cross-panel composites)
167+
let group_cols = non_facet_partition_cols(&layer.partition_by, spec);
168+
let group_info = match compute_group_indices(&df, &group_cols)? {
164169
Some(info) => info,
165170
None => return Ok((df, None)),
166171
};

src/plot/layer/position/jitter.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
//! - `normal`: normal/Gaussian distribution with ~95% of points within the width
1616
1717
use super::{
18-
compute_dodge_offsets, compute_group_indices, is_continuous_scale, Layer, PositionTrait,
19-
PositionType,
18+
compute_dodge_offsets, compute_group_indices, is_continuous_scale, non_facet_partition_cols,
19+
Layer, PositionTrait, PositionType,
2020
};
2121
use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue};
2222
use crate::{naming, DataFrame, GgsqlError, Plot, Result};
@@ -491,9 +491,11 @@ fn apply_jitter(df: DataFrame, layer: &Layer, spec: &Plot) -> Result<DataFrame>
491491
let mut rng = rand::thread_rng();
492492
let n_rows = df.height();
493493

494-
// Compute group info for dodge-first behavior
494+
// Compute group info for dodge-first behavior, excluding facet columns
495+
// so group count reflects within-panel groups
496+
let group_cols = non_facet_partition_cols(&layer.partition_by, spec);
495497
let group_info = if dodge {
496-
compute_group_indices(&df, &layer.partition_by)?
498+
compute_group_indices(&df, &group_cols)?
497499
} else {
498500
None
499501
};

src/plot/layer/position/mod.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,32 @@ pub fn compute_dodge_offsets(
104104
}
105105
}
106106

107+
/// Filter facet columns out of partition_by for position adjustments that
108+
/// compute group indices (dodge, jitter).
109+
///
110+
/// Facet columns in partition_by inflate the group count — e.g., 2 fill groups
111+
/// across 2 facet panels would be seen as 4 composite groups instead of 2.
112+
/// Position adjustments should operate per-panel, so facet columns must be excluded.
113+
pub fn non_facet_partition_cols(partition_by: &[String], spec: &Plot) -> Vec<String> {
114+
let facet_cols: std::collections::HashSet<String> = spec
115+
.facet
116+
.as_ref()
117+
.map(|f| {
118+
f.layout
119+
.internal_facet_names()
120+
.into_iter()
121+
.map(|aes| crate::naming::aesthetic_column(&aes))
122+
.collect()
123+
})
124+
.unwrap_or_default();
125+
126+
partition_by
127+
.iter()
128+
.filter(|col| !facet_cols.contains(*col))
129+
.cloned()
130+
.collect()
131+
}
132+
107133
// Re-export position implementations
108134
pub use dodge::{compute_group_indices, Dodge, GroupIndices};
109135
pub use identity::Identity;

src/plot/layer/position/stack.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,22 +266,33 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re
266266
// 2. stack_end_col = lag(stack_col, 1, 0) - the bar bottom/start (previous stack top)
267267
// The cumsum naturally stacks across the grouping column values
268268

269+
// Build the partition columns for .over(): group column + facet columns.
270+
// Facet columns must be included so stacking resets per facet panel,
271+
// matching ggplot2 where position adjustments are computed per-panel.
272+
let mut over_cols: Vec<Expr> = vec![col(&group_col)];
273+
if let Some(ref facet) = spec.facet {
274+
for aes in facet.layout.internal_facet_names() {
275+
let facet_col = naming::aesthetic_column(&aes);
276+
over_cols.push(col(&facet_col));
277+
}
278+
}
279+
269280
// Treat NA heights as 0 for stacking
270281
// Compute cumulative sums (shared by all modes)
271282
let lf = lf
272283
.with_column(col(&stack_col).fill_null(lit(0.0)).alias(&stack_col))
273284
.with_column(
274285
col(&stack_col)
275286
.cum_sum(false)
276-
.over([col(&group_col)])
287+
.over(&over_cols)
277288
.alias("__cumsum__"),
278289
)
279290
.with_column(
280291
col(&stack_col)
281292
.cum_sum(false)
282293
.shift(lit(1))
283294
.fill_null(lit(0.0))
284-
.over([col(&group_col)])
295+
.over(&over_cols)
285296
.alias("__cumsum_lag__"),
286297
);
287298

@@ -293,15 +304,15 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re
293304
vec!["__cumsum__", "__cumsum_lag__"],
294305
),
295306
StackMode::Fill(target) => {
296-
let total = col(&stack_col).sum().over([col(&group_col)]);
307+
let total = col(&stack_col).sum().over(&over_cols);
297308
(
298309
(col("__cumsum__") / total.clone() * lit(target)).alias(&stack_col),
299310
(col("__cumsum_lag__") / total * lit(target)).alias(&stack_end_col),
300311
vec!["__cumsum__", "__cumsum_lag__"],
301312
)
302313
}
303314
StackMode::Center => {
304-
let half_total = col(&stack_col).sum().over([col(&group_col)]) / lit(2.0);
315+
let half_total = col(&stack_col).sum().over(&over_cols) / lit(2.0);
305316
(
306317
(col("__cumsum__") - half_total.clone()).alias(&stack_col),
307318
(col("__cumsum_lag__") - half_total).alias(&stack_end_col),

0 commit comments

Comments
 (0)