|
22 | 22 | #' @param style either `1` or `2`, indicating two different types of validation plots. |
23 | 23 | #' @param min_max a bool indicating if min-max normalization will be used to scale the testing output, RMSE, predictive mean and std from the |
24 | 24 | #' emulator. Defaults to `TRUE`. |
| 25 | +#' @param color a character string indicating the color map to use when `style = 2`: |
| 26 | +#' * `'magma'` (or `'A'`) |
| 27 | +#' * `'inferno'` (or `'B'`) |
| 28 | +#' * `'plasma'` (or '`C`') |
| 29 | +#' * `'viridis'` (or `'D'`) |
| 30 | +#' * `'cividis'` (or `'E'`) |
| 31 | +#' * `'rocket'` (or `'F'`) |
| 32 | +#' * `'mako'` (or `'G'`) |
| 33 | +#' * `'turbo'` (or `'H'`) |
| 34 | +#' |
| 35 | +#' Defaults to `'turbo'` (or `'H'`). |
25 | 36 | #' @param verb a bool indicating if the trace information on plotting will be printed during the function execution. |
26 | 37 | #' Defaults to `TRUE`. |
27 | 38 | #' @param force same as that of [validate()]. |
|
59 | 70 | #' @rdname plot |
60 | 71 | #' @method plot dgp |
61 | 72 | #' @export |
62 | | -plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) { |
| 73 | +plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) { |
63 | 74 | if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE) |
64 | 75 | if( !is.null(cores) ) { |
65 | 76 | cores <- as.integer(cores) |
@@ -131,11 +142,11 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean |
131 | 142 | dat[["y_validate"]] <- loo_res$y_train[,l] |
132 | 143 | dat[["std"]] <- loo_res$std[,l] |
133 | 144 | if ( min_max ){ |
134 | | - p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 145 | + p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
135 | 146 | ggplot2::ggtitle(sprintf("O%i: NRMSE = %.2f%%", l, loo_res$nrmse[l]*100)) + |
136 | 147 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
137 | 148 | } else { |
138 | | - p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 149 | + p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
139 | 150 | ggplot2::ggtitle(sprintf("O%i: RMSE = %.6f", l, loo_res$rmse[l])) + |
140 | 151 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
141 | 152 | } |
@@ -313,11 +324,11 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean |
313 | 324 | dat[["y_validate"]] <- oos_res$y_test[,l] |
314 | 325 | dat[["std"]] <- oos_res$std[,l] |
315 | 326 | if ( min_max ) { |
316 | | - p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 327 | + p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
317 | 328 | ggplot2::ggtitle(sprintf("O%i: NRMSE = %.2f%%", l, oos_res$nrmse[l]*100)) + |
318 | 329 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
319 | 330 | } else { |
320 | | - p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 331 | + p_list[[l]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
321 | 332 | ggplot2::ggtitle(sprintf("O%i: RMSE = %.6f", l, oos_res$rmse[l])) + |
322 | 333 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
323 | 334 | } |
@@ -387,7 +398,7 @@ plot.dgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean |
387 | 398 | #' @rdname plot |
388 | 399 | #' @method plot lgp |
389 | 400 | #' @export |
390 | | -plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) { |
| 401 | +plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, threading = FALSE, ...) { |
391 | 402 | if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE) |
392 | 403 | if( !is.null(cores) ) { |
393 | 404 | cores <- as.integer(cores) |
@@ -574,11 +585,11 @@ plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean |
574 | 585 | dat[["y_validate"]] <- y_test_list[[k]][,l] |
575 | 586 | dat[["std"]] <- oos_res$std[[k]][,l] |
576 | 587 | if ( min_max ) { |
577 | | - p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 588 | + p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
578 | 589 | ggplot2::ggtitle(sprintf("E%iO%i: NRMSE = %.2f%%", k, l, oos_res$nrmse[[k]][l]*100)) + |
579 | 590 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
580 | 591 | } else { |
581 | | - p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 592 | + p_list[[counter]] <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
582 | 593 | ggplot2::ggtitle(sprintf("E%iO%i: RMSE = %.6f", k, l, oos_res$rmse[[k]][l])) + |
583 | 594 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
584 | 595 | } |
@@ -649,7 +660,7 @@ plot.lgp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean |
649 | 660 | #' @rdname plot |
650 | 661 | #' @method plot gp |
651 | 662 | #' @export |
652 | | -plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, verb = TRUE, force = FALSE, cores = 1, ...) { |
| 663 | +plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_var', style = 1, min_max = TRUE, color = 'turbo', verb = TRUE, force = FALSE, cores = 1, ...) { |
653 | 664 | if ( style!=1&style!=2 ) stop("'style' must be either 1 or 2.", call. = FALSE) |
654 | 665 | if( !is.null(cores) ) { |
655 | 666 | cores <- as.integer(cores) |
@@ -716,11 +727,11 @@ plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_ |
716 | 727 | dat[["y_validate"]] <- loo_res$y_train[,1] |
717 | 728 | dat[["std"]] <- loo_res$std[,1] |
718 | 729 | if ( min_max ) { |
719 | | - p <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 730 | + p <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
720 | 731 | ggplot2::ggtitle(sprintf('NRMSE = %.2f%%', loo_res$nrmse*100)) + |
721 | 732 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
722 | 733 | } else { |
723 | | - p <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 734 | + p <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
724 | 735 | ggplot2::ggtitle(sprintf('RMSE = %.6f', loo_res$rmse)) + |
725 | 736 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
726 | 737 | } |
@@ -870,11 +881,11 @@ plot.gp <- function(x, x_test = NULL, y_test = NULL, dim = NULL, method = 'mean_ |
870 | 881 | dat[["y_validate"]] <- oos_res$y_test[,1] |
871 | 882 | dat[["std"]] <- oos_res$std[,1] |
872 | 883 | if ( min_max ) { |
873 | | - p <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 884 | + p <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
874 | 885 | ggplot2::ggtitle(sprintf('NRMSE = %.2f%%', oos_res$nrmse*100)) + |
875 | 886 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
876 | 887 | } else { |
877 | | - p <- plot_style_2(as.data.frame(dat), method, min_max) + |
| 888 | + p <- plot_style_2(as.data.frame(dat), method, min_max, color) + |
878 | 889 | ggplot2::ggtitle(sprintf('RMSE = %.6f', oos_res$rmse)) + |
879 | 890 | ggplot2::theme(plot.title = ggplot2::element_text(size=10)) |
880 | 891 | } |
@@ -1025,7 +1036,7 @@ plot_style_1 <- function(dat, method, dim, isdup) { |
1025 | 1036 | return(p) |
1026 | 1037 | } |
1027 | 1038 |
|
1028 | | -plot_style_2 <- function(dat, method, min_max) { |
| 1039 | +plot_style_2 <- function(dat, method, min_max, color) { |
1029 | 1040 | y_min <- min(dat$y_validate) |
1030 | 1041 | y_max <- max(dat$y_validate) |
1031 | 1042 | std_min <- min(dat$std) |
@@ -1053,9 +1064,9 @@ plot_style_2 <- function(dat, method, min_max) { |
1053 | 1064 | ggplot2::geom_point(alpha=0.8, size=1.5) |
1054 | 1065 |
|
1055 | 1066 | if (isTRUE(min_max)){ |
1056 | | - p <- p + ggplot2::scale_colour_viridis_c("Normalized Predictive SD", option = 'turbo', breaks=seq(0,1,0.2), labels=c('0.0','0.2','0.4','0.6','0.8','1.0')) |
| 1067 | + p <- p + ggplot2::scale_colour_viridis_c("Normalized Predictive SD", option = color, breaks=seq(0,1,0.2), labels=c('0.0','0.2','0.4','0.6','0.8','1.0')) |
1057 | 1068 | } else { |
1058 | | - p <- p + ggplot2::scale_colour_viridis_c("Predictive SD", option = 'turbo') |
| 1069 | + p <- p + ggplot2::scale_colour_viridis_c("Predictive SD", option = color) |
1059 | 1070 | } |
1060 | 1071 |
|
1061 | 1072 | p <- p + |
|
0 commit comments