Skip to content

Commit 28933d3

Browse files
committed
updates on plot() and set_linked_idx()
1. updates on plot() for different color maps. 2. updates on set_linked_idx() for different emulators on the basis of a same emulator.
1 parent 1fdb1d8 commit 28933d3

4 files changed

Lines changed: 51 additions & 16 deletions

File tree

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
- Thanks to @tjmckinley, some R memory issues due to the underlying Python implementations are rectified.
77
- `set_seed()` function is added to ensure reproducible results from the package.
88
- A bug is fixed when candidate sets `x_cand` and `y_cand` are provided to `design()`.
9+
- One can choose different color palettes using the new argument `color` in `plot()` when `style = 2`.
10+
- `set_linked_idx()` allows constructions of different (D)GP emulators (in terms of different connections to the feeding layers) from a same (D)GP emulator.
911

1012
# dgpsi 2.1.6
1113

R/plot.R

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222
#' @param style either `1` or `2`, indicating two different types of validation plots.
2323
#' @param min_max a bool indicating if min-max normalization will be used to scale the testing output, RMSE, predictive mean and std from the
2424
#' emulator. Defaults to `TRUE`.
25+
#' @param color a character string indicating the color map to use when `style = 2`:
26+
#' * `'magma'` (or `'A'`)
27+
#' * `'inferno'` (or `'B'`)
28+
#' * `'plasma'` (or '`C`')
29+
#' * `'viridis'` (or `'D'`)
30+
#' * `'cividis'` (or `'E'`)
31+
#' * `'rocket'` (or `'F'`)
32+
#' * `'mako'` (or `'G'`)
33+
#' * `'turbo'` (or `'H'`)
34+
#'
35+
#' Defaults to `'turbo'` (or `'H'`).
2536
#' @param verb a bool indicating if the trace information on plotting will be printed during the function execution.
2637
#' Defaults to `TRUE`.
2738
#' @param force same as that of [validate()].
@@ -59,7 +70,7 @@ NULL
5970
#' @rdname plot
6071
#' @method plot dgp
6172
#' @export
62-
plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) {
73+
plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) {
6374
if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE)
6475
if( !is.null(cores) ) {
6576
cores <- as.integer(cores)
@@ -131,11 +142,11 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean
131142
dat[["y_validate"]] <- loo_res$y_train[,l]
132143
dat[["std"]] <- loo_res$std[,l]
133144
if ( min_max ){
134-
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) +
145+
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
135146
ggplot2::ggtitle(sprintf("O%i: NRMSE = %.2f%%", l, loo_res$nrmse[l]*100)) +
136147
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
137148
} else {
138-
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) +
149+
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
139150
ggplot2::ggtitle(sprintf("O%i: RMSE = %.6f", l, loo_res$rmse[l])) +
140151
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
141152
}
@@ -313,11 +324,11 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean
313324
dat[["y_validate"]] <- oos_res$y_test[,l]
314325
dat[["std"]] <- oos_res$std[,l]
315326
if ( min_max ) {
316-
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) +
327+
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
317328
ggplot2::ggtitle(sprintf("O%i: NRMSE = %.2f%%", l, oos_res$nrmse[l]*100)) +
318329
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
319330
} else {
320-
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) +
331+
p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
321332
ggplot2::ggtitle(sprintf("O%i: RMSE = %.6f", l, oos_res$rmse[l])) +
322333
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
323334
}
@@ -387,7 +398,7 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean
387398
#' @rdname plot
388399
#' @method plot lgp
389400
#' @export
390-
plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) {
401+
plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) {
391402
if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE)
392403
if( !is.null(cores) ) {
393404
cores <- as.integer(cores)
@@ -574,11 +585,11 @@ plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean
574585
dat[["y_validate"]] <- y_test_list[[k]][,l]
575586
dat[["std"]] <- oos_res$std[[k]][,l]
576587
if ( min_max ) {
577-
p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max) +
588+
p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
578589
ggplot2::ggtitle(sprintf("E%iO%i: NRMSE = %.2f%%", k, l, oos_res$nrmse[[k]][l]*100)) +
579590
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
580591
} else {
581-
p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max) +
592+
p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max, color) +
582593
ggplot2::ggtitle(sprintf("E%iO%i: RMSE = %.6f", k, l, oos_res$rmse[[k]][l])) +
583594
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
584595
}
@@ -649,7 +660,7 @@ plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean
649660
#' @rdname plot
650661
#' @method plot gp
651662
#' @export
652-
plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, ...) {
663+
plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, ...) {
653664
if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE)
654665
if( !is.null(cores) ) {
655666
cores <- as.integer(cores)
@@ -716,11 +727,11 @@ plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_
716727
dat[["y_validate"]] <- loo_res$y_train[,1]
717728
dat[["std"]] <- loo_res$std[,1]
718729
if ( min_max ) {
719-
p <- plot_style_2(as.data.frame(dat), method, min_max) +
730+
p <- plot_style_2(as.data.frame(dat), method, min_max, color) +
720731
ggplot2::ggtitle(sprintf('NRMSE = %.2f%%', loo_res$nrmse*100)) +
721732
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
722733
} else {
723-
p <- plot_style_2(as.data.frame(dat), method, min_max) +
734+
p <- plot_style_2(as.data.frame(dat), method, min_max, color) +
724735
ggplot2::ggtitle(sprintf('RMSE = %.6f', loo_res$rmse)) +
725736
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
726737
}
@@ -870,11 +881,11 @@ plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_
870881
dat[["y_validate"]] <- oos_res$y_test[,1]
871882
dat[["std"]] <- oos_res$std[,1]
872883
if ( min_max ) {
873-
p <- plot_style_2(as.data.frame(dat), method, min_max) +
884+
p <- plot_style_2(as.data.frame(dat), method, min_max, color) +
874885
ggplot2::ggtitle(sprintf('NRMSE = %.2f%%', oos_res$nrmse*100)) +
875886
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
876887
} else {
877-
p <- plot_style_2(as.data.frame(dat), method, min_max) +
888+
p <- plot_style_2(as.data.frame(dat), method, min_max, color) +
878889
ggplot2::ggtitle(sprintf('RMSE = %.6f', oos_res$rmse)) +
879890
ggplot2::theme(plot.title = ggplot2::element_text(size=10))
880891
}
@@ -1025,7 +1036,7 @@ plot_style_1 <- function(dat, method, dim, isdup) {
10251036
return(p)
10261037
}
10271038

1028-
plot_style_2 <- function(dat, method, min_max) {
1039+
plot_style_2 <- function(dat, method, min_max, color) {
10291040
y_min <- min(dat$y_validate)
10301041
y_max <- max(dat$y_validate)
10311042
std_min <- min(dat$std)
@@ -1053,9 +1064,9 @@ plot_style_2 <- function(dat, method, min_max) {
10531064
ggplot2::geom_point(alpha=0.8, size=1.5)
10541065

10551066
if (isTRUE(min_max)){
1056-
p <- p + ggplot2::scale_colour_viridis_c("Normalized Predictive SD", option = 'turbo', breaks=seq(0,1,0.2), labels=c('0.0','0.2','0.4','0.6','0.8','1.0'))
1067+
p <- p + ggplot2::scale_colour_viridis_c("Normalized Predictive SD", option = color, breaks=seq(0,1,0.2), labels=c('0.0','0.2','0.4','0.6','0.8','1.0'))
10571068
} else {
1058-
p <- p + ggplot2::scale_colour_viridis_c("Predictive SD", option = 'turbo')
1069+
p <- p + ggplot2::scale_colour_viridis_c("Predictive SD", option = color)
10591070
}
10601071

10611072
p <- p +

R/utils.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,12 @@ summary.lgp <- function(object, ...) {
333333
#' @export
334334
set_linked_idx <- function(object, idx) {
335335
idx <- reticulate::np_array(as.integer(idx - 1))
336+
object[['constructor_obj']] <- pkg.env$copy$deepcopy(object[['constructor_obj']])
337+
object[['emulator_obj']] <- pkg.env$copy$deepcopy(object[['emulator_obj']])
338+
object[['container_obj']] <- pkg.env$copy$deepcopy(object[['container_obj']])
336339
object$container_obj$set_local_input(idx)
340+
pkg.env$py_gc$collect()
341+
gc(full=T)
337342
return(object)
338343
}
339344

man/plot.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)