Skip to content

Commit 8637e82

Browse files
authored
Merge branch 'master' into fix/validate-chain-list-colnames-check
2 parents b8fe768 + 14c1c74 commit 8637e82

12 files changed

Lines changed: 179 additions & 22 deletions

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# bayesplot (development version)
22

33
* Fixed `validate_chain_list()` colnames check to compare all chains, not just the first two.
4+
* Added test verifying `legend_move("none")` behaves equivalently to `legend_none()`.
5+
* Added singleton-dimension edge-case tests for exported `_data()` functions.
6+
* Validate empty list and zero-row matrix inputs in `nuts_params.list()`.
7+
* Validate user-provided `pit` values in `ppc_loo_pit_data()` and `ppc_loo_pit_qq()`, rejecting non-numeric inputs, missing values, and values outside `[0, 1]`.
48
* New `show_marginal` argument to `ppd_*()` functions to show the PPD - the marginal predictive distribution by @mattansb (#425)
59
* `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated.
610
* Fixed missing `drop = FALSE` in `nuts_params.CmdStanMCMC()`.

R/bayesplot-extractors.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,18 @@ nuts_params.stanreg <-
145145
#' @export
146146
#' @method nuts_params list
147147
nuts_params.list <- function(object, pars = NULL, ...) {
148+
if (length(object) == 0) {
149+
abort("'object' must be a non-empty list.")
150+
}
151+
148152
if (!all(sapply(object, is.matrix))) {
149153
abort("All list elements should be matrices.")
150154
}
151155

156+
if (any(vapply(object, nrow, integer(1)) == 0)) {
157+
abort("All matrices in the list must have at least one row.")
158+
}
159+
152160
dd <- lapply(object, dim)
153161
if (length(unique(dd)) != 1) {
154162
abort("All matrices in the list must have the same dimensions.")

R/ppc-loo.R

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,10 @@ ppc_loo_pit_data <-
302302
boundary_correction = TRUE,
303303
grid_len = 512) {
304304
if (!is.null(pit)) {
305-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
305+
pit <- validate_pit(pit)
306+
if (boundary_correction && length(pit) < 2L) {
307+
abort("At least 2 PIT values are required when 'boundary_correction' is TRUE.")
308+
}
306309
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
307310
} else {
308311
suggested_package("rstantools")
@@ -348,7 +351,7 @@ ppc_loo_pit_qq <- function(y,
348351

349352
compare <- match.arg(compare)
350353
if (!is.null(pit)) {
351-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
354+
pit <- validate_pit(pit)
352355
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
353356
} else {
354357
suggested_package("rstantools")
@@ -795,14 +798,6 @@ ppc_loo_ribbon <-
795798
# Generate boundary corrected values via a linear convolution using a
796799
# 1-D Gaussian window filter. This method uses the "reflection method"
797800
# to estimate these pvalues and helps speed up the code
798-
if (any(is.infinite(x))) {
799-
warn(paste(
800-
"Ignored", sum(is.infinite(x)),
801-
"Non-finite PIT values are invalid for KDE boundary correction method"
802-
))
803-
x <- x[is.finite(x)]
804-
}
805-
806801
if (grid_len < 100) {
807802
grid_len <- 100
808803
}
@@ -819,6 +814,10 @@ ppc_loo_ribbon <-
819814
# 1-D Convolution
820815
bc_pvals <- .linear_convolution(x, bw, grid_counts, grid_breaks, grid_len)
821816

817+
if (all(is.na(bc_pvals))) {
818+
abort("KDE boundary correction produced all NA values.")
819+
}
820+
822821
# Generate vector of x-axis values for plotting based on binned relative freqs
823822
n_breaks <- length(grid_breaks)
824823

tests/testthat/test-convenience-functions.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ test_that("legend_text returns correct theme object", {
120120
theme(legend.text = element_text(color = "purple", size = 16))
121121
)
122122
})
123+
test_that("legend_move('none') behaves like legend_none", {
124+
expect_equal(
125+
legend_move("none")$legend.position,
126+
legend_none()$legend.position,
127+
ignore_attr = TRUE
128+
)
129+
})
123130

124131
# axis and facet text --------------------------------------------------
125132
test_that("xaxis_text returns correct theme object", {
@@ -186,8 +193,6 @@ test_that("overlay_function returns the correct object", {
186193
a$constructor <- b$constructor <- NULL
187194
expect_equal(a, b, ignore_function_env = TRUE)
188195
})
189-
190-
191196
# tagged functions -------------------------------------------------------
192197

193198
test_that("as_tagged_function handles bare function (symbol)", {

tests/testthat/test-extractors.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ x <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = 1:3, b = rnorm(3)))
99

1010
# nuts_params and log_posterior methods -----------------------------------
1111
test_that("nuts_params.list throws errors", {
12+
expect_error(nuts_params.list(list()), "non-empty list")
13+
1214
x[[3]] <- c(a = 1:3, b = rnorm(3))
1315
expect_error(nuts_params.list(x), "list elements should be matrices")
1416

@@ -17,6 +19,20 @@ test_that("nuts_params.list throws errors", {
1719

1820
x[[3]] <- cbind(a = 1:4, b = rnorm(4))
1921
expect_error(nuts_params.list(x), "same dimensions")
22+
23+
zero_row <- list(cbind(a = numeric(0), b = numeric(0)))
24+
expect_error(nuts_params.list(zero_row), "at least one row")
25+
26+
zero_row_nonfirst <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = numeric(0), b = numeric(0)))
27+
expect_error(nuts_params.list(zero_row_nonfirst), "at least one row")
28+
})
29+
30+
test_that("nuts_params.list works with single-chain list", {
31+
single <- list(cbind(a = 1:3, b = rnorm(3)))
32+
np <- nuts_params.list(single)
33+
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))
34+
expect_true(all(np$Chain == 1L))
35+
expect_equal(nrow(np), 6L)
2036
})
2137

2238
test_that("nuts_params.list parameter selection ok", {

tests/testthat/test-ppc-discrete.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,27 @@ test_that("ppc_bars_data includes all levels", {
7777
expect_equal(d3$h[2], 0, ignore_attr = TRUE)
7878
})
7979

80+
test_that("ppc_bars_data handles single observation and single draw", {
81+
y1 <- 2L
82+
yrep1 <- matrix(c(1L, 2L, 3L, 2L, 2L), ncol = 1)
83+
d <- ppc_bars_data(y1, yrep1)
84+
expect_s3_class(d, "data.frame")
85+
expect_equal(d$y_obs[d$x == 2], 1)
86+
87+
# single draw: interval collapses to a point
88+
y_s <- c(1L, 2L, 3L, 2L)
89+
yrep_s <- matrix(c(1L, 2L, 2L, 3L), nrow = 1)
90+
d2 <- ppc_bars_data(y_s, yrep_s)
91+
expect_equal(d2$l, d2$m, ignore_attr = TRUE)
92+
expect_equal(d2$m, d2$h, ignore_attr = TRUE)
93+
})
94+
95+
test_that("ppc_bars_data prob = 0 collapses interval to median", {
96+
d <- ppc_bars_data(y_ord, yrep_ord, prob = 0)
97+
expect_equal(d$l, d$m, ignore_attr = TRUE)
98+
expect_equal(d$m, d$h, ignore_attr = TRUE)
99+
})
100+
80101

81102
# rootograms -----------------------------------------------------------
82103
yrep3 <- matrix(yrep2, nrow = 5, ncol = ncol(yrep2), byrow = TRUE)

tests/testthat/test-ppc-distributions.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,14 @@ test_that("ppd_data handles a single replicate matrix", {
237237
expect_equal(d$value, c(11, 21))
238238
})
239239

240+
test_that("ppd_data handles single observation (single column)", {
241+
ypred <- matrix(c(1, 2, 3), ncol = 1)
242+
d <- ppd_data(ypred)
243+
expect_equal(nrow(d), 3)
244+
expect_true(all(d$y_id == 1))
245+
expect_equal(d$value, c(1, 2, 3))
246+
})
247+
240248

241249
# Visual tests -----------------------------------------------------------------
242250

tests/testthat/test-ppc-errors.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ test_that("ppc_error_data with group returns exact structure", {
8585
expect_equal(d$group[d$rep_id == 1], group)
8686
})
8787

88+
test_that("ppc_error_data handles single observation", {
89+
y1 <- 5
90+
yrep1 <- matrix(c(4, 6, 5), ncol = 1)
91+
d <- ppc_error_data(y1, yrep1)
92+
expect_equal(nrow(d), 3)
93+
expect_equal(d$value, y1 - yrep1[, 1])
94+
expect_true(all(d$y_obs == 5))
95+
})
96+
8897

8998
# Visual tests -----------------------------------------------------------------
9099

tests/testthat/test-ppc-intervals.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ test_that("ppd_intervals_data + y_obs column same as ppc_intervals_data", {
7272
expect_equal(tibble::add_column(d_group2, y_obs = d_group$y_obs, .after = "y_id"), d_group)
7373
})
7474

75+
test_that("ppd_intervals_data handles single observation and single draw", {
76+
yrep_1obs <- matrix(rnorm(25), ncol = 1)
77+
d <- ppd_intervals_data(yrep_1obs)
78+
expect_equal(nrow(d), 1)
79+
expect_true(d$ll <= d$l && d$l <= d$m && d$m <= d$h && d$h <= d$hh)
80+
81+
# single draw: all quantiles collapse to the value
82+
yrep_1draw <- matrix(rnorm(10), nrow = 1)
83+
d2 <- ppd_intervals_data(yrep_1draw)
84+
expect_equal(d2$ll, d2$m)
85+
expect_equal(d2$hh, d2$m)
86+
})
87+
7588
test_that("ppc_intervals_data does math correctly", {
7689
d <- ppc_intervals_data(y, yrep, prob = .4, prob_outer = .8)
7790
qs <- unname(quantile(yrep[, 1], c(.1, .3, .5, .7, .9)))

tests/testthat/test-ppc-loo.R

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,50 @@ test_that("ppc_loo_pit_overlay works with boundary_correction=FALSE", {
5959
expect_gg(p1)
6060
})
6161

62-
test_that(".kde_correction warns when PIT values are non-finite", {
63-
set.seed(123)
64-
pit_vals <- c(stats::runif(500), Inf)
65-
expect_warning(
66-
out <- .kde_correction(pit_vals, bw = "nrd0", grid_len = 128),
67-
"Non-finite PIT values are invalid"
62+
test_that("ppc_loo_pit_data validates user-provided pit values", {
63+
expect_error(
64+
ppc_loo_pit_data(pit = c(0.5, Inf)),
65+
"between 0 and 1"
66+
)
67+
expect_error(
68+
ppc_loo_pit_data(pit = c(-1, 0.5)),
69+
"between 0 and 1"
70+
)
71+
expect_error(
72+
ppc_loo_pit_data(pit = c(0.5, NA)),
73+
"NAs not allowed"
74+
)
75+
expect_error(
76+
ppc_loo_pit_data(pit = "not numeric"),
77+
"is.numeric"
78+
)
79+
expect_error(
80+
ppc_loo_pit_data(pit = c(Inf, -Inf, Inf)),
81+
"between 0 and 1"
82+
)
83+
expect_error(
84+
ppc_loo_pit_data(pit = 0.5, boundary_correction = TRUE),
85+
"At least 2 PIT values"
86+
)
87+
})
88+
89+
test_that("ppc_loo_pit_qq validates user-provided pit values", {
90+
expect_error(
91+
ppc_loo_pit_qq(pit = c(0.5, Inf)),
92+
"between 0 and 1"
93+
)
94+
expect_error(
95+
ppc_loo_pit_qq(pit = c(-1, 0.5)),
96+
"between 0 and 1"
97+
)
98+
expect_error(
99+
ppc_loo_pit_qq(pit = c(0.5, NA)),
100+
"NAs not allowed"
101+
)
102+
expect_error(
103+
ppc_loo_pit_qq(pit = "not numeric"),
104+
"is.numeric"
68105
)
69-
expect_type(out, "list")
70-
expect_true(all(c("xs", "bc_pvals") %in% names(out)))
71-
expect_equal(length(out$xs), 128)
72-
expect_equal(length(out$bc_pvals), 128)
73106
})
74107

75108
test_that("ppc_loo_pit_qq returns ggplot object", {
@@ -399,3 +432,10 @@ test_that("ppc_loo_pit_data returns the expected structure for both boundary mod
399432
expect_equal(nrow(yrep_rows), grid_len * n_samples)
400433
expect_false(anyNA(d_bc$x))
401434
})
435+
436+
test_that("ppc_loo_pit_data works with a single pit value", {
437+
d <- suppressMessages(ppc_loo_pit_data(pit = 0.5, boundary_correction = FALSE, samples = 3))
438+
y_rows <- d[d$is_y, ]
439+
expect_equal(nrow(y_rows), 1)
440+
expect_equal(y_rows$value, 0.5)
441+
})

0 commit comments

Comments
 (0)