Skip to content

Commit c41c058

Browse files
committed
Let cv_mb returns biweight robust mean metrics directly. Update tests accordingly. Add desription for outputs.
1 parent 6b6c7c8 commit c41c058

2 files changed

Lines changed: 21 additions & 9 deletions

File tree

R/mass-balance-optim.R

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ lsq_mb <- function(hat, obs, lambda, mus, sigmas, log.seasons, log.ann, N, sInd)
5252
#'
5353
#' This is a wrapper for `lsq_mb()`. It first calculates `hat`, then calls `lsq_mb()`.
5454
#' This is used in `optim()`, so it returns a scalar.
55-
#' @param beta Parameters
55+
#' @param beta Parameters
5656
#' @param X Inputs, must have columns of 1 added
5757
#' @param Y Observed Dry, Wet, and Annual log-transformed flows
5858
#' @inheritParams lsq_mb
@@ -209,15 +209,22 @@ mb_reconstruction <- function(instQ, pc.list, start.year, lambda = 1,
209209
#' @inheritParams mb_reconstruction
210210
#' @param pc.list List of PC matrices
211211
#' @param cv.folds A list containing the cross validation folds
212-
#' @param return.type If 'mb', only the objective function value is returned. If 'metrics', all metrics are returned. If 'Q', all Q predictions are returned.
212+
#' @param return.type The type of results to be returned. Several types are possible to suit multiple use cases.
213+
#' \describe{
214+
#' \item{`fval`}{Only the objective function value (penalized least squares) is returned; this is useful for the outer optimization for site selection.}
215+
#' \item{`metrics`}{all performance metrics are returned.}
216+
#' \item{`metric means`}{the Tukey's biweight robust mean of each metric is returned.}
217+
#' \item{`Q`}{The predicted flow in each cross-validation run is returned. This is the most basic output, so that you can use it to calculate other metrics that are not provided by the package.}
218+
#' }
219+
#' @return A `data.table` containing cross-validation results (metrics, fval, or metric means) for each target.
213220
#' @examples
214221
#' cvFolds <- make_Z(1922:2003, nRuns = 50, frac = 0.25, contiguous = TRUE)
215222
#' cv <- cv_mb(p1Seasonal, pc3seasons, cvFolds, 1750, log.trans = 1:3, return.type = 'metrics')
216223
#' @export
217224
cv_mb <- function(instQ, pc.list, cv.folds, start.year,
218225
lambda = 1,
219226
log.trans = NULL, force.standardize = FALSE,
220-
return.type = c('mb', 'metrics', 'Q')) {
227+
return.type = c('fval', 'metrics', 'metric means', 'Q')) {
221228

222229
# Setup
223230
years <- start.year:max(instQ$year)
@@ -345,15 +352,15 @@ cv_mb <- function(instQ, pc.list, cv.folds, start.year,
345352

346353
fval <- lsq_mb(hat[valInd], Y[valInd], lambda, cm, csd, log.seasons, log.ann, N, sInd)
347354

348-
if (return.type == 'mb') {
355+
if (return.type == 'fval') {
349356
ans <- fval
350357
} else {
351358
Qcv <- merge(
352359
back_trans(hat, yearsInst, cm, csd, log.trans, N, seasons),
353360
instQ, by = c('year', 'season'))
354361
if (return.type == 'Q') {
355362
ans <- Qcv
356-
} else {
363+
} else { # Metrics
357364
metrics <- Qcv[, as.data.table(t(calculate_metrics(Q, Qa, z))), by = season]
358365
metrics[, fval := fval, by = season]
359366
ans <- metrics
@@ -365,9 +372,14 @@ cv_mb <- function(instQ, pc.list, cv.folds, start.year,
365372
# Run cv ---------------------------------------------------
366373

367374
if (return.type == 'mb') { # A vector of fval
368-
unlist(lapply(cv.folds, one_cv), use.names = FALSE)
375+
out <- unlist(lapply(cv.folds, one_cv), use.names = FALSE)
369376
} else { # A data.table of all metrics or all reps
370-
rbindlist(lapply(cv.folds, one_cv), idcol = 'rep')
377+
outReps <- rbindlist(lapply(cv.folds, one_cv), idcol = 'rep')
378+
if (return.type == 'metric means') {
379+
out <- outReps[, lapply(.SD, tbrm), .SDcols = c('R2', 'RE', 'CE', 'nRMSE', 'KGE'), by = season]
380+
} else out <- outReps
371381
}
382+
out[, season := factor(season, seasons)]
383+
out[order(season)]
372384
}
373385

tests/testthat/test-p1.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ context("P1")
33
fit <- mb_reconstruction(p1Seasonal, pc3seasons, 1750, log.trans = 1:3)
44
set.seed(24)
55
cvFolds <- make_Z(1922:2003, nRuns = 50, frac = 0.25, contiguous = TRUE)
6-
cv <- cv_mb(p1Seasonal, pc3seasons, cvFolds, 1750, log.trans = 1:3, return.type = 'metrics')
6+
cv <- cv_mb(p1Seasonal, pc3seasons, cvFolds, 1750, log.trans = 1:3, return.type = 'metric means')
77

88
test_that("P.1 reconstruction produces data.table output", {
99
expect_is(fit, 'data.table')
@@ -18,5 +18,5 @@ test_that("P.1 cross-validation produces data.table output", {
1818
})
1919

2020
test_that("P.1 cross-validation is numerically correct", {
21-
expect_equal(cv$R2[1], 0.4111, tolerance = 1e-4)
21+
expect_equal(cv$R2[1], 0.4588, tolerance = 1e-4)
2222
})

0 commit comments

Comments
 (0)