Skip to content

Commit 48504dd

Browse files
committed
Bug fix of prune()
Fixed a bug for prune() if the emulator has a likelihood layer
1 parent b87bd0f commit 48504dd

2 files changed

Lines changed: 35 additions & 4 deletions

File tree

R/utils.R

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -774,10 +774,15 @@ trace_plot <- function(object, layer = NULL, node = 1) {
774774
#' @return An updated `object` that could be an instance of `gp`, `dgp`, or `bundle` (of GP emulators) class.
775775
#'
776776
#' @note
777-
#' The function requires a DGP emulator that has been trained with a dataset comprising a minimum size equal to `min_size` in `control`.
777+
#' * The function requires a DGP emulator that has been trained with a dataset comprising a minimum size equal to `min_size` in `control`.
778778
#' If the training dataset size is smaller than this, it is suggested to enrich the design of the DGP emulator and prune its
779779
#' structure dynamically using the `design()` function. Depending on the design of the DGP emulator, the static pruning may not be accurate.
780780
#' It is thus suggested to implement dynamic pruning as a part of the sequential design via `design()`.
781+
#' * The following slots:
782+
#' - `loo` and `oos` created by [validate()]; and
783+
#' - `results` created by [predict()];
784+
#'
785+
#' in `object` will be removed and not contained in the returned object.
781786
#'
782787
#' @details See further examples and tutorials at <https://mingdeyu.github.io/dgpsi-R/>.
783788
#' @examples
@@ -869,6 +874,7 @@ prune <- function(object, control = list(), verb = TRUE) {
869874
}
870875

871876
is.finish <- FALSE
877+
cropping_times <- 0
872878
while (!is.finish){
873879
crop_id_list <- create_drop_list(object)
874880
r2 <- object$constructor_obj$aggregate_r2()
@@ -881,12 +887,28 @@ prune <- function(object, control = list(), verb = TRUE) {
881887
if (N_cropped!=0) {
882888
object <- copy_in_design(object)
883889
object <- crop(object, crop_id_list, refit_cores = as.integer(1), verb = verb)
884-
if ( !inherits(object,"dgp") ) is.finish <- TRUE
890+
if ( !inherits(object,"dgp") ) {
891+
is.finish <- TRUE
892+
} else {
893+
n_layer <- object$constructor_obj$n_layer
894+
if (object$constructor_obj$all_layer[[n_layer]][[1]]$type!='gp') {
895+
n_layer <- n_layer - 1
896+
if (n_layer == 1) is.finish <- TRUE
897+
}
898+
}
899+
cropping_times <- cropping_times + 1
885900
} else {
886-
if (verb) message("No more GP nodes can be pruned.", appendLF = FALSE)
887901
is.finish <- TRUE
888902
}
889903
}
904+
if (cropping_times == 0) {
905+
if (verb) message("No GP nodes can be pruned.", appendLF = FALSE)
906+
} else {
907+
if ('loo' %in% names(object)) object[['loo']] <- NULL
908+
if ('oos' %in% names(object)) object[['oos']] <- NULL
909+
if ('results' %in% names(object)) object[['results']] <- NULL
910+
if (verb) message(" * No more GP nodes can be pruned.", appendLF = FALSE)
911+
}
890912
return(object)
891913
}
892914

man/prune.Rd

Lines changed: 10 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)