From 07697a2d0ebc0c19bba3f176c1625def5e87f404 Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 17 Sep 2018 15:36:15 +0200 Subject: [PATCH 01/17] loo_compare: same output regardless of number of models --- NAMESPACE | 2 + R/compare.R | 38 +++-------- R/loo_compare.R | 155 +++++++++++++++++++++++++++++++++++++++++++++ man/compare.Rd | 11 ---- man/loo_compare.Rd | 93 +++++++++++++++++++++++++++ 5 files changed, 260 insertions(+), 39 deletions(-) create mode 100644 R/loo_compare.R create mode 100644 man/loo_compare.Rd diff --git a/NAMESPACE b/NAMESPACE index d6d43242..6d4c43a3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,6 +11,7 @@ S3method(dim,waic) S3method(loo,"function") S3method(loo,array) S3method(loo,matrix) +S3method(loo_compare,default) S3method(loo_model_weights,default) S3method(plot,loo) S3method(plot,psis) @@ -51,6 +52,7 @@ export(loo) export(loo.array) export(loo.function) export(loo.matrix) +export(loo_compare) export(loo_i) export(loo_model_weights) export(loo_model_weights.default) diff --git a/R/compare.R b/R/compare.R index 5923b11f..c9b341ec 100644 --- a/R/compare.R +++ b/R/compare.R @@ -120,24 +120,6 @@ compare <- function(..., x = list()) { } } -#' @rdname compare -#' @export -#' @param digits For the print method only, the number of digits to use when -#' printing. -#' @param simplify For the print method only, should only the essential columns -#' of the summary matrix be printed when comparing more than two models? The -#' entire matrix is always returned, but by default only the most important -#' columns are printed. -print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { - xcopy <- x - if (NCOL(xcopy) >= 2 && simplify) { - patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" - xcopy <- xcopy[, grepl(patts, colnames(xcopy))] - } - print(.fr(xcopy, digits), quote = FALSE) - invisible(x) -} - # internal ---------------------------------------------------------------- @@ -155,13 +137,13 @@ compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), chec structure(comp, class = "compare.loo") } -elpd_diffs <- function(loo_a, loo_b) { - pt_a <- loo_a$pointwise - pt_b <- loo_b$pointwise - elpd <- grep("^elpd", colnames(pt_a)) - pt_b[, elpd] - pt_a[, elpd] -} -se_elpd_diff <- function(diffs) { - N <- length(diffs) - sqrt(N) * sd(diffs) -} +# elpd_diffs <- function(loo_a, loo_b) { +# pt_a <- loo_a$pointwise +# pt_b <- loo_b$pointwise +# elpd <- grep("^elpd", colnames(pt_a)) +# pt_b[, elpd] - pt_a[, elpd] +# } +# se_elpd_diff <- function(diffs) { +# N <- length(diffs) +# sqrt(N) * sd(diffs) +# } diff --git a/R/loo_compare.R b/R/loo_compare.R new file mode 100644 index 00000000..91e86750 --- /dev/null +++ b/R/loo_compare.R @@ -0,0 +1,155 @@ +#' Model comparison +#' +#' Compare fitted models on LOO or WAIC. +#' +#' @export +#' @param ... At least two objects returned by \code{\link{loo}} (or +#' \code{\link{waic}}). +#' @param x A list of at least two objects returned by \code{\link{loo}} (or +#' \code{\link{waic}}). This argument can be used as an alternative to +#' specifying the objects in \code{...}. +#' +#' @return A vector or matrix with class \code{'compare.loo'} that has its own +#' print method. If exactly two objects are provided in \code{...} or +#' \code{x}, then the difference in expected predictive accuracy and the +#' standard error of the difference are returned. If more than two objects are +#' provided then a matrix of summary information is returned (see +#' \strong{Details}). +#' +#' @details +#' When comparing two fitted models, we can estimate the difference in their +#' expected predictive accuracy by the difference in \code{elpd_loo} or +#' \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the +#' deviance scale). +#' +#' \emph{When that difference, \code{elpd_diff}, is positive then the expected +#' predictive accuracy for the second model is higher. A negative +#' \code{elpd_diff} favors the first model.} +#' +#' When using \code{compare()} with more than two models, the values in the +#' \code{elpd_diff} and \code{se_diff} columns of the returned matrix are +#' computed by making pairwise comparisons between each model and the model +#' with the best ELPD (i.e., the model in the first row). +#' Although the \code{elpd_diff} column is equal to the difference in +#' \code{elpd_loo}, do not expect the \code{se_diff} column to be equal to the +#' the difference in \code{se_elpd_loo}. +#' +#' To compute the standard error of the difference in ELPD we use a +#' paired estimate to take advantage of the fact that the same set of \eqn{N} +#' data points was used to fit both models. These calculations should be most +#' useful when \eqn{N} is large, because then non-normality of the +#' distribution is not such an issue when estimating the uncertainty in these +#' sums. These standard errors, for all their flaws, should give a better +#' sense of uncertainty than what is obtained using the current standard +#' approach of comparing differences of deviances to a Chi-squared +#' distribution, a practice derived for Gaussian linear models or +#' asymptotically, and which only applies to nested models in any case. +#' +#' @template loo-and-psis-references +#' +#' @examples +#' \dontrun{ +#' loo1 <- loo(log_lik1) +#' loo2 <- loo(log_lik2) +#' print(loo_compare(loo1, loo2), digits = 3) +#' print(loo_compare(x = list(loo1, loo2))) +#' +#' waic1 <- waic(log_lik1) +#' waic2 <- waic(log_lik2) +#' loo_compare(waic1, waic2) +#' } +#' +loo_compare <- function(x, ...) { + UseMethod("loo_compare") +} + +#' @rdname loo_compare +#' @export +loo_compare.default <- function(x, ...) { + if (is.loo(x)) { + dots <- list(...) + dots <- c(list(x), dots) + nms <- c(as.character(match.call(expand.dots = TRUE))[-1L]) + } else { + if (!is.list(x) || !length(x)) { + stop("'x' must be a list if not a 'loo' object.", call. = FALSE) + } + if (length(list(...))) { + stop("If 'x' is a list then '...' should not be specified.", call. = FALSE) + } + dots <- x + nms <- names(dots) + if (!length(nms)) { + nms <- paste0("model", seq_along(dots)) + } + } + + if (!all(sapply(dots, is.loo))) { + stop("All inputs should have class 'loo'.") + } + if (length(dots) <= 1L) { + stop("'loo_compare' requires at least two models.") + } else { + Ns <- sapply(dots, function(x) nrow(x$pointwise)) + if (!all(Ns == Ns[1L])) { + stop("Not all models have the same number of data points.", call. = FALSE) + } + + x <- sapply(dots, function(x) { + est <- x$estimates + setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) ) + }) + + colnames(x) <- nms + rnms <- rownames(x) + comp <- x + ord <- order(x[grep("^elpd", rnms), ], decreasing = TRUE) + comp <- t(comp)[ord, ] + patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$") + col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))), + use.names = FALSE) + comp <- comp[, col_ord] + + # compute elpd_diff and se_elpd_diff relative to best model + rnms <- rownames(comp) + diffs <- mapply(elpd_diffs, dots[ord[1]], dots[ord]) + elpd_diff <- apply(diffs, 2, sum) + se_diff <- apply(diffs, 2, se_elpd_diff) + comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) + rownames(comp) <- rnms + class(comp) <- c("compare.loo", class(comp)) + comp + } +} + +#' @rdname loo_compare +#' @export +#' @param digits For the print method only, the number of digits to use when +#' printing. +#' @param simplify For the print method only, should only the essential columns +#' of the summary matrix be printed? The entire matrix is always returned, but +#' by default only the most important columns are printed. +print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { + xcopy <- x + if (simplify) { + patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" + xcopy <- xcopy[, grepl(patts, colnames(xcopy))] + } + print(.fr(xcopy, digits), quote = FALSE) + invisible(x) +} + + + +# internal ---------------------------------------------------------------- + +elpd_diffs <- function(loo_a, loo_b) { + pt_a <- loo_a$pointwise + pt_b <- loo_b$pointwise + elpd <- grep("^elpd", colnames(pt_a)) + pt_b[, elpd] - pt_a[, elpd] +} +se_elpd_diff <- function(diffs) { + N <- length(diffs) + sqrt(N) * sd(diffs) +} diff --git a/man/compare.Rd b/man/compare.Rd index 6acb10f0..cb30a7c2 100644 --- a/man/compare.Rd +++ b/man/compare.Rd @@ -2,12 +2,9 @@ % Please edit documentation in R/compare.R \name{compare} \alias{compare} -\alias{print.compare.loo} \title{Model comparison} \usage{ compare(..., x = list()) - -\method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) } \arguments{ \item{...}{At least two objects returned by \code{\link{loo}} (or @@ -16,14 +13,6 @@ compare(..., x = list()) \item{x}{A list of at least two objects returned by \code{\link{loo}} (or \code{\link{waic}}). This argument can be used as an alternative to specifying the objects in \code{...}.} - -\item{digits}{For the print method only, the number of digits to use when -printing.} - -\item{simplify}{For the print method only, should only the essential columns -of the summary matrix be printed when comparing more than two models? The -entire matrix is always returned, but by default only the most important -columns are printed.} } \value{ A vector or matrix with class \code{'compare.loo'} that has its own diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd new file mode 100644 index 00000000..efb2560f --- /dev/null +++ b/man/loo_compare.Rd @@ -0,0 +1,93 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/loo_compare.R +\name{loo_compare} +\alias{loo_compare} +\alias{loo_compare.default} +\alias{print.compare.loo} +\title{Model comparison} +\usage{ +loo_compare(x, ...) + +\method{loo_compare}{default}(x, ...) + +\method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) +} +\arguments{ +\item{x}{A list of at least two objects returned by \code{\link{loo}} (or +\code{\link{waic}}). This argument can be used as an alternative to +specifying the objects in \code{...}.} + +\item{...}{At least two objects returned by \code{\link{loo}} (or +\code{\link{waic}}).} + +\item{digits}{For the print method only, the number of digits to use when +printing.} + +\item{simplify}{For the print method only, should only the essential columns +of the summary matrix be printed? The entire matrix is always returned, but +by default only the most important columns are printed.} +} +\value{ +A vector or matrix with class \code{'compare.loo'} that has its own + print method. If exactly two objects are provided in \code{...} or + \code{x}, then the difference in expected predictive accuracy and the + standard error of the difference are returned. If more than two objects are + provided then a matrix of summary information is returned (see + \strong{Details}). +} +\description{ +Compare fitted models on LOO or WAIC. +} +\details{ +When comparing two fitted models, we can estimate the difference in their + expected predictive accuracy by the difference in \code{elpd_loo} or + \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the + deviance scale). + + \emph{When that difference, \code{elpd_diff}, is positive then the expected + predictive accuracy for the second model is higher. A negative + \code{elpd_diff} favors the first model.} + + When using \code{compare()} with more than two models, the values in the + \code{elpd_diff} and \code{se_diff} columns of the returned matrix are + computed by making pairwise comparisons between each model and the model + with the best ELPD (i.e., the model in the first row). + Although the \code{elpd_diff} column is equal to the difference in + \code{elpd_loo}, do not expect the \code{se_diff} column to be equal to the + the difference in \code{se_elpd_loo}. + + To compute the standard error of the difference in ELPD we use a + paired estimate to take advantage of the fact that the same set of \eqn{N} + data points was used to fit both models. These calculations should be most + useful when \eqn{N} is large, because then non-normality of the + distribution is not such an issue when estimating the uncertainty in these + sums. These standard errors, for all their flaws, should give a better + sense of uncertainty than what is obtained using the current standard + approach of comparing differences of deviances to a Chi-squared + distribution, a practice derived for Gaussian linear models or + asymptotically, and which only applies to nested models in any case. +} +\examples{ +\dontrun{ +loo1 <- loo(log_lik1) +loo2 <- loo(log_lik2) +print(loo_compare(loo1, loo2), digits = 3) +print(loo_compare(x = list(loo1, loo2))) + +waic1 <- waic(log_lik1) +waic2 <- waic(log_lik2) +loo_compare(waic1, waic2) +} + +} +\references{ +Vehtari, A., Gelman, A., and Gabry, J. (2017a). Practical + Bayesian model evaluation using leave-one-out cross-validation and WAIC. + \emph{Statistics and Computing}. 27(5), 1413--1432. + doi:10.1007/s11222-016-9696-4. + (\href{http://link.springer.com/article/10.1007\%2Fs11222-016-9696-4}{published + version}, \href{http://arxiv.org/abs/1507.04544}{arXiv preprint}). + +Vehtari, A., Gelman, A., and Gabry, J. (2017b). Pareto smoothed + importance sampling. arXiv preprint: \url{http://arxiv.org/abs/1507.02646/} +} From ca24e36c09ba72030f7aa2be2ba1a70af7eb560a Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 17 Sep 2018 15:53:26 +0200 Subject: [PATCH 02/17] update print.compare.loo [ci skip] --- R/compare.R | 1 + R/loo_compare.R | 18 +++++------------- man/loo_compare.Rd | 15 ++++----------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/R/compare.R b/R/compare.R index c9b341ec..c388c418 100644 --- a/R/compare.R +++ b/R/compare.R @@ -60,6 +60,7 @@ #' } #' compare <- function(..., x = list()) { + .Deprecated("loo_compare") dots <- list(...) if (length(dots)) { if (length(x)) { diff --git a/R/loo_compare.R b/R/loo_compare.R index 91e86750..33a70f6f 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -3,18 +3,11 @@ #' Compare fitted models on LOO or WAIC. #' #' @export -#' @param ... At least two objects returned by \code{\link{loo}} (or -#' \code{\link{waic}}). -#' @param x A list of at least two objects returned by \code{\link{loo}} (or -#' \code{\link{waic}}). This argument can be used as an alternative to -#' specifying the objects in \code{...}. +#' @param x An object of class \code{"loo"} or a list of such objects. +#' @param ... Additional objects of class \code{"loo"}. #' -#' @return A vector or matrix with class \code{'compare.loo'} that has its own -#' print method. If exactly two objects are provided in \code{...} or -#' \code{x}, then the difference in expected predictive accuracy and the -#' standard error of the difference are returned. If more than two objects are -#' provided then a matrix of summary information is returned (see -#' \strong{Details}). +#' @return A matrix with class \code{'compare.loo'} that has its own +#' print method. See \strong{Details}. #' #' @details #' When comparing two fitted models, we can estimate the difference in their @@ -132,8 +125,7 @@ loo_compare.default <- function(x, ...) { print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { xcopy <- x if (simplify) { - patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" - xcopy <- xcopy[, grepl(patts, colnames(xcopy))] + xcopy <- xcopy[, c("elpd_diff", "se_diff")] } print(.fr(xcopy, digits), quote = FALSE) invisible(x) diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd index efb2560f..f0ad1fd1 100644 --- a/man/loo_compare.Rd +++ b/man/loo_compare.Rd @@ -13,12 +13,9 @@ loo_compare(x, ...) \method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) } \arguments{ -\item{x}{A list of at least two objects returned by \code{\link{loo}} (or -\code{\link{waic}}). This argument can be used as an alternative to -specifying the objects in \code{...}.} +\item{x}{An object of class \code{"loo"} or a list of such objects.} -\item{...}{At least two objects returned by \code{\link{loo}} (or -\code{\link{waic}}).} +\item{...}{Additional objects of class \code{"loo"}.} \item{digits}{For the print method only, the number of digits to use when printing.} @@ -28,12 +25,8 @@ of the summary matrix be printed? The entire matrix is always returned, but by default only the most important columns are printed.} } \value{ -A vector or matrix with class \code{'compare.loo'} that has its own - print method. If exactly two objects are provided in \code{...} or - \code{x}, then the difference in expected predictive accuracy and the - standard error of the difference are returned. If more than two objects are - provided then a matrix of summary information is returned (see - \strong{Details}). +A matrix with class \code{'compare.loo'} that has its own + print method. See \strong{Details}. } \description{ Compare fitted models on LOO or WAIC. From 1ce802387ecf2f02e09f7dc351905db3166eb381 Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 17 Sep 2018 16:19:32 +0200 Subject: [PATCH 03/17] export is.loo, is.waic, etc. [ci skip] --- NAMESPACE | 4 ++++ R/compare.R | 14 ++------------ R/helpers.R | 15 --------------- R/loo.R | 12 ++++++++++++ R/loo_compare.R | 10 +++++++--- R/psis.R | 5 +++++ R/waic.R | 6 ++++++ man/loo.Rd | 6 ++++++ man/waic.Rd | 3 +++ 9 files changed, 45 insertions(+), 30 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 6d4c43a3..14833a2c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -45,6 +45,10 @@ export(example_loglik_array) export(example_loglik_matrix) export(extract_log_lik) export(gpdfit) +export(is.loo) +export(is.psis) +export(is.psis_loo) +export(is.waic) export(kfold_split_balanced) export(kfold_split_random) export(kfold_split_stratified) diff --git a/R/compare.R b/R/compare.R index c388c418..d2f93b49 100644 --- a/R/compare.R +++ b/R/compare.R @@ -88,6 +88,7 @@ compare <- function(..., x = list()) { loo1 <- dots[[1]] loo2 <- dots[[2]] comp <- compare_two_models(loo1, loo2) + class(comp) <- c(class(comp), "old_compare.loo") return(comp) } else { Ns <- sapply(dots, function(x) nrow(x$pointwise)) @@ -116,7 +117,7 @@ compare <- function(..., x = list()) { se_diff <- apply(diffs, 2, se_elpd_diff) comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) rownames(comp) <- rnms - class(comp) <- c("compare.loo", class(comp)) + class(comp) <- c("compare.loo", class(comp), "old_compare.loo") comp } } @@ -137,14 +138,3 @@ compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), chec comp <- c(elpd_diff = sum(diffs), se = se_elpd_diff(diffs)) structure(comp, class = "compare.loo") } - -# elpd_diffs <- function(loo_a, loo_b) { -# pt_a <- loo_a$pointwise -# pt_b <- loo_b$pointwise -# elpd <- grep("^elpd", colnames(pt_a)) -# pt_b[, elpd] - pt_a[, elpd] -# } -# se_elpd_diff <- function(diffs) { -# N <- length(diffs) -# sqrt(N) * sd(diffs) -# } diff --git a/R/helpers.R b/R/helpers.R index 6c3644b5..4d970fae 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -37,21 +37,6 @@ table_of_estimates <- function(x) { } -# checking classes -------------------------------------------------------- -is.psis <- function(x) { - inherits(x, "psis") && is.list(x) -} -is.loo <- function(x) { - inherits(x, "loo") -} -is.psis_loo <- function(x) { - inherits(x, "psis_loo") && is.loo(x) -} -is.waic <- function(x) { - inherits(x, "waic") && is.loo(x) -} - - # validating and reshaping arrays/matrices ------------------------------- #' Check for NAs and non-finite values in log-lik (or log-ratios) diff --git a/R/loo.R b/R/loo.R index 4555c069..e36168a4 100644 --- a/R/loo.R +++ b/R/loo.R @@ -397,6 +397,18 @@ dim.psis_loo <- function(x) { } +#' @rdname loo +#' @export +is.loo <- function(x) { + inherits(x, "loo") +} + +#' @rdname loo +#' @export +is.psis_loo <- function(x) { + inherits(x, "psis_loo") && is.loo(x) +} + # internal ---------------------------------------------------------------- diff --git a/R/loo_compare.R b/R/loo_compare.R index 33a70f6f..f896a998 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -124,8 +124,13 @@ loo_compare.default <- function(x, ...) { #' by default only the most important columns are printed. print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { xcopy <- x - if (simplify) { - xcopy <- xcopy[, c("elpd_diff", "se_diff")] + if (inherits(xcopy, "old_compare.loo")) { + if (NCOL(xcopy) >= 2 && simplify) { + patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" + xcopy <- xcopy[, grepl(patts, colnames(xcopy))] + } + } else if (simplify) { + xcopy <- xcopy[, c("elpd_diff", "se_diff")] } print(.fr(xcopy, digits), quote = FALSE) invisible(x) @@ -134,7 +139,6 @@ print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { # internal ---------------------------------------------------------------- - elpd_diffs <- function(loo_a, loo_b) { pt_a <- loo_a$pointwise pt_b <- loo_b$pointwise diff --git a/R/psis.R b/R/psis.R index b91ac0aa..690af8ea 100644 --- a/R/psis.R +++ b/R/psis.R @@ -170,6 +170,11 @@ dim.psis <- function(x) { attr(x, "dims") } +#' @export +is.psis <- function(x) { + inherits(x, "psis") && is.list(x) +} + # internal ---------------------------------------------------------------- diff --git a/R/waic.R b/R/waic.R index cf57f6aa..478801d4 100644 --- a/R/waic.R +++ b/R/waic.R @@ -121,6 +121,12 @@ dim.waic <- function(x) { attr(x, "dims") } +#' @rdname waic +#' @export +is.waic <- function(x) { + inherits(x, "waic") && is.loo(x) +} + # internal ---------------------------------------------------------------- diff --git a/man/loo.Rd b/man/loo.Rd index 3593883d..b68b6377 100644 --- a/man/loo.Rd +++ b/man/loo.Rd @@ -6,6 +6,8 @@ \alias{loo.matrix} \alias{loo.function} \alias{loo_i} +\alias{is.loo} +\alias{is.psis_loo} \title{Efficient approximate leave-one-out cross-validation (LOO)} \usage{ loo(x, ...) @@ -20,6 +22,10 @@ loo(x, ...) save_psis = FALSE, cores = getOption("mc.cores", 1)) loo_i(i, llfun, ..., data = NULL, draws = NULL, r_eff = NULL) + +is.loo(x) + +is.psis_loo(x) } \arguments{ \item{x}{A log-likelihood array, matrix, or function. See the \strong{Methods diff --git a/man/waic.Rd b/man/waic.Rd index 3cd302ce..09b1b796 100644 --- a/man/waic.Rd +++ b/man/waic.Rd @@ -5,6 +5,7 @@ \alias{waic.array} \alias{waic.matrix} \alias{waic.function} +\alias{is.waic} \title{Widely applicable information criterion (WAIC)} \usage{ waic(x, ...) @@ -14,6 +15,8 @@ waic(x, ...) \method{waic}{matrix}(x, ...) \method{waic}{function}(x, ..., data = NULL, draws = NULL) + +is.waic(x) } \arguments{ \item{x}{A log-likelihood array, matrix, or function. See the \strong{Methods From a654d21644eea33bb5dad1478ffdff43249288fe Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 17 Sep 2018 16:29:10 +0200 Subject: [PATCH 04/17] update doc [ci skip] --- R/loo_compare.R | 26 ++++++++++++-------------- man/loo_compare.Rd | 26 ++++++++++++-------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/R/loo_compare.R b/R/loo_compare.R index f896a998..87c79906 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -6,8 +6,8 @@ #' @param x An object of class \code{"loo"} or a list of such objects. #' @param ... Additional objects of class \code{"loo"}. #' -#' @return A matrix with class \code{'compare.loo'} that has its own -#' print method. See \strong{Details}. +#' @return A matrix with class \code{"compare.loo"} that has its own +#' print method. See the \strong{Details} section for more . #' #' @details #' When comparing two fitted models, we can estimate the difference in their @@ -15,19 +15,17 @@ #' \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the #' deviance scale). #' -#' \emph{When that difference, \code{elpd_diff}, is positive then the expected -#' predictive accuracy for the second model is higher. A negative -#' \code{elpd_diff} favors the first model.} +#' When using \code{loo_compare()}, the returned matrix will have one row per +#' model and several columns of estimates. The values in the \code{elpd_diff} +#' and \code{se_diff} columns of the returned matrix are computed by making +#' pairwise comparisons between each model and the model with the largest ELPD +#' (the model in the first row). For this reason the \code{elpd_diff} column +#' will always have the value \code{0} in the first row (i.e., the difference +#' between the preferred model and itself) and negative values in subsequent +#' rows for the remaining models. #' -#' When using \code{compare()} with more than two models, the values in the -#' \code{elpd_diff} and \code{se_diff} columns of the returned matrix are -#' computed by making pairwise comparisons between each model and the model -#' with the best ELPD (i.e., the model in the first row). -#' Although the \code{elpd_diff} column is equal to the difference in -#' \code{elpd_loo}, do not expect the \code{se_diff} column to be equal to the -#' the difference in \code{se_elpd_loo}. -#' -#' To compute the standard error of the difference in ELPD we use a +#' To compute the standard error of the difference in ELPD --- which should +#' not be expected to equal the difference of the standard errors --- we use a #' paired estimate to take advantage of the fact that the same set of \eqn{N} #' data points was used to fit both models. These calculations should be most #' useful when \eqn{N} is large, because then non-normality of the diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd index f0ad1fd1..2528b6bd 100644 --- a/man/loo_compare.Rd +++ b/man/loo_compare.Rd @@ -25,8 +25,8 @@ of the summary matrix be printed? The entire matrix is always returned, but by default only the most important columns are printed.} } \value{ -A matrix with class \code{'compare.loo'} that has its own - print method. See \strong{Details}. +A matrix with class \code{"compare.loo"} that has its own + print method. See the \strong{Details} section for more . } \description{ Compare fitted models on LOO or WAIC. @@ -37,19 +37,17 @@ When comparing two fitted models, we can estimate the difference in their \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the deviance scale). - \emph{When that difference, \code{elpd_diff}, is positive then the expected - predictive accuracy for the second model is higher. A negative - \code{elpd_diff} favors the first model.} + When using \code{loo_compare()}, the returned matrix will have one row per + model and several columns of estimates. The values in the \code{elpd_diff} + and \code{se_diff} columns of the returned matrix are computed by making + pairwise comparisons between each model and the model with the largest ELPD + (the model in the first row). For this reason the \code{elpd_diff} column + will always have the value \code{0} in the first row (i.e., the difference + between the preferred model and itself) and negative values in subsequent + rows for the remaining models. - When using \code{compare()} with more than two models, the values in the - \code{elpd_diff} and \code{se_diff} columns of the returned matrix are - computed by making pairwise comparisons between each model and the model - with the best ELPD (i.e., the model in the first row). - Although the \code{elpd_diff} column is equal to the difference in - \code{elpd_loo}, do not expect the \code{se_diff} column to be equal to the - the difference in \code{se_elpd_loo}. - - To compute the standard error of the difference in ELPD we use a + To compute the standard error of the difference in ELPD --- which should + not be expected to equal the difference of the standard errors --- we use a paired estimate to take advantage of the fact that the same set of \eqn{N} data points was used to fit both models. These calculations should be most useful when \eqn{N} is large, because then non-normality of the From db2e8ba00d7bbc89a49a1101faca7d8fb0c63b30 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 12:20:19 +0200 Subject: [PATCH 05/17] export find_model_names() --- NAMESPACE | 1 + R/loo_compare.R | 110 +++++++++++++++++++++++++--------------- man/find_model_names.Rd | 18 +++++++ 3 files changed, 88 insertions(+), 41 deletions(-) create mode 100644 man/find_model_names.Rd diff --git a/NAMESPACE b/NAMESPACE index 14833a2c..fcbb1885 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -44,6 +44,7 @@ export(compare) export(example_loglik_array) export(example_loglik_matrix) export(extract_log_lik) +export(find_model_names) export(gpdfit) export(is.loo) export(is.psis) diff --git a/R/loo_compare.R b/R/loo_compare.R index 87c79906..cf160354 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -59,58 +59,54 @@ loo_compare <- function(x, ...) { loo_compare.default <- function(x, ...) { if (is.loo(x)) { dots <- list(...) - dots <- c(list(x), dots) - nms <- c(as.character(match.call(expand.dots = TRUE))[-1L]) + loos <- c(list(x), dots) } else { if (!is.list(x) || !length(x)) { - stop("'x' must be a list if not a 'loo' object.", call. = FALSE) + stop("'x' must be a list if not a 'loo' object.") } if (length(list(...))) { - stop("If 'x' is a list then '...' should not be specified.", call. = FALSE) - } - dots <- x - nms <- names(dots) - if (!length(nms)) { - nms <- paste0("model", seq_along(dots)) + stop("If 'x' is a list then '...' should not be specified.") } + loos <- x } - if (!all(sapply(dots, is.loo))) { + if (!all(sapply(loos, is.loo))) { stop("All inputs should have class 'loo'.") } - if (length(dots) <= 1L) { + if (length(loos) <= 1L) { stop("'loo_compare' requires at least two models.") - } else { - Ns <- sapply(dots, function(x) nrow(x$pointwise)) - if (!all(Ns == Ns[1L])) { - stop("Not all models have the same number of data points.", call. = FALSE) - } + } - x <- sapply(dots, function(x) { - est <- x$estimates - setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) ) - }) - - colnames(x) <- nms - rnms <- rownames(x) - comp <- x - ord <- order(x[grep("^elpd", rnms), ], decreasing = TRUE) - comp <- t(comp)[ord, ] - patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$") - col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))), - use.names = FALSE) - comp <- comp[, col_ord] - - # compute elpd_diff and se_elpd_diff relative to best model - rnms <- rownames(comp) - diffs <- mapply(elpd_diffs, dots[ord[1]], dots[ord]) - elpd_diff <- apply(diffs, 2, sum) - se_diff <- apply(diffs, 2, se_elpd_diff) - comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) - rownames(comp) <- rnms - class(comp) <- c("compare.loo", class(comp)) - comp + Ns <- sapply(loos, function(x) nrow(x$pointwise)) + if (!all(Ns == Ns[1L])) { + stop("Not all models have the same number of data points.") } + + tmp <- sapply(loos, function(x) { + est <- x$estimates + setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) ) + }) + + colnames(tmp) <- find_model_names(loos) + rnms <- rownames(tmp) + comp <- tmp + ord <- order(tmp[grep("^elpd", rnms), ], decreasing = TRUE) + comp <- t(comp)[ord, ] + patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$") + col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))), + use.names = FALSE) + comp <- comp[, col_ord] + + # compute elpd_diff and se_elpd_diff relative to best model + rnms <- rownames(comp) + diffs <- mapply(elpd_diffs, loos[ord[1]], loos[ord]) + elpd_diff <- apply(diffs, 2, sum) + se_diff <- apply(diffs, 2, se_elpd_diff) + comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) + rownames(comp) <- rnms + + class(comp) <- c("compare.loo", class(comp)) + return(comp) } #' @rdname loo_compare @@ -127,7 +123,7 @@ print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" xcopy <- xcopy[, grepl(patts, colnames(xcopy))] } - } else if (simplify) { + } else if (NCOL(xcopy) >= 2 && simplify) { xcopy <- xcopy[, c("elpd_diff", "se_diff")] } print(.fr(xcopy, digits), quote = FALSE) @@ -147,3 +143,35 @@ se_elpd_diff <- function(diffs) { N <- length(diffs) sqrt(N) * sd(diffs) } + + + +#' Find the model names associated with loo objects +#' +#' @export +#' @keywords internal +#' @param x List of loo objects. +#' @return Character vector of model names the same length as x. +#' +find_model_names <- function(x) { + out_names <- character(length(x)) + + names1 <- names(x) + names2 <- lapply(x, "attr", "model_name", exact = TRUE) + names3 <- lapply(x, "[[", "model_name") + names4 <- paste0("model", seq_along(x)) + + for (j in seq_along(x)) { + if (isTRUE(nzchar(names1[j]))) { + out_names[j] <- names1[j] + } else if (length(names2[[j]])) { + out_names[j] <- names2[[j]] + } else if (length(names3[[j]])) { + out_names[j] <- names3[[j]] + } else { + out_names[j] <- names4[j] + } + } + + return(out_names) +} diff --git a/man/find_model_names.Rd b/man/find_model_names.Rd new file mode 100644 index 00000000..575665c7 --- /dev/null +++ b/man/find_model_names.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/loo_compare.R +\name{find_model_names} +\alias{find_model_names} +\title{Find the model names associated with loo objects} +\usage{ +find_model_names(x) +} +\arguments{ +\item{x}{List of loo objects.} +} +\value{ +Character vector of model names the same length as x. +} +\description{ +Find the model names associated with loo objects +} +\keyword{internal} From 33f332875dc3e37af4ca9600e79472152f610a13 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 14:28:35 +0200 Subject: [PATCH 06/17] loo_compare: check if list --- R/loo_compare.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/loo_compare.R b/R/loo_compare.R index cf160354..a22b77a3 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -154,6 +154,7 @@ se_elpd_diff <- function(diffs) { #' @return Character vector of model names the same length as x. #' find_model_names <- function(x) { + stopifnot(is.list(x)) out_names <- character(length(x)) names1 <- names(x) From 91bbaf95ca3f9aac5e8dd0420f7efa0ce04c7766 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 14:38:07 +0200 Subject: [PATCH 07/17] update tests [ci skip] --- compare_three_models.rds | Bin 0 -> 346 bytes compare_two_models.rds | Bin 0 -> 127 bytes tests/testthat/test_compare.R | 19 ++++++++++++------- 3 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 compare_three_models.rds create mode 100644 compare_two_models.rds diff --git a/compare_three_models.rds b/compare_three_models.rds new file mode 100644 index 0000000000000000000000000000000000000000..a7abb8a5d358ffd5ded78c663f17645169fbb43d GIT binary patch literal 346 zcmV-g0j2&QiwFP!000001B>8dU|?WkU}j}xU}6R`nfZW(1OpTt5U>i4pX2)GfC2BR ziQ8D<;`ZCJ+_&fpi`buD|MsEDTbBbN4?o;J=@fAw^yxI=!)!YbL_AGyGCQ2)Ajo06 zS+F3%K~QR{5bsTU2Z2wlcQO_TIdEwlNW8i;$U!ZA>T8AzoeqvCZUjl)<+T6!H}y@? z=Xm>%vlHgzrXREa*l=M|`pI~Q6o~l_sSxuWGN9%wLd{o%ny&ycpMime3Fu5l1{MZR zkf%~IbAc2u*sDM;5Obg@;Q%VhOUz9z2C{{r>IA@QA@WS+hH#bl`|8;1G*Q zdP+j#1C~o`Dm!}>+~(vH%i3HmlxILIVJ_J diu0?p3+b@8S67vboYDADp5bL{(3_P&djU?{F+TtR literal 0 HcmV?d00001 diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 8eade396..15373cc0 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -10,6 +10,11 @@ LLarr3 <- array(rnorm(prod(dim(LLarr)), c(LLarr), 1), dim = dim(LLarr)) w1 <- SW(waic(LLarr)) w2 <- SW(waic(LLarr2)) + + + +# Tests for deprecated compare() ------------------------------------------ + test_that("compare throws appropriate errors", { w3 <- SW(waic(LLarr[,, -1])) w4 <- SW(waic(LLarr[,, -(1:2)])) @@ -26,27 +31,27 @@ test_that("compare throws appropriate errors", { regexp = "same number of data points") expect_error(loo::compare(w1, w2, w3), regexp = "same number of data points") - expect_silent(loo::compare(w1, w2)) - expect_silent(loo::compare(w1, w1, w2)) + expect_warning(loo::compare(w1, w2), "Deprecated") + expect_warning(loo::compare(w1, w1, w2), "Deprecated") }) test_that("compare returns expected result (2 models)", { - comp1 <- compare(w1, w1) + comp1 <- loo::compare(w1, w1) expect_output(print(comp1), "elpd_diff") expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0)) - comp2 <- compare(w1, w2) + comp2 <- loo::compare(w1, w2) expect_equal_to_reference(comp2, "compare_two_models.rds") expect_named(comp2, c("elpd_diff", "se")) expect_s3_class(comp2, "compare.loo") # specifying objects via ... and via arg x gives equal results - expect_equal(comp2, compare(x = list(w1, w2))) + expect_equal(comp2, loo::compare(x = list(w1, w2))) }) test_that("compare returns expected result (3 models)", { w3 <- SW(waic(LLarr3)) - comp1 <- compare(w1, w2, w3) + comp1 <- loo::compare(w1, w2, w3) expect_equal( colnames(comp1), @@ -62,5 +67,5 @@ test_that("compare returns expected result (3 models)", { # specifying objects via '...' gives equivalent results (equal # except rownames) to using 'x' argument - expect_equivalent(comp1, compare(x = list(w1, w2, w3))) + expect_equivalent(comp1, loo::compare(x = list(w1, w2, w3))) }) From d96618d9d36b5d148ffe15ae4b58cdfca56c92d7 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 15:42:00 +0200 Subject: [PATCH 08/17] kfold generic --- NAMESPACE | 1 + R/kfold-generic.R | 28 ++++++++++++++++++++++++++++ R/kfold-helpers.R | 22 ++++++++++++++-------- R/loo_compare.R | 2 +- man/kfold-generic.Rd | 31 +++++++++++++++++++++++++++++++ man/kfold-helpers.Rd | 20 ++++++++++++-------- 6 files changed, 87 insertions(+), 17 deletions(-) create mode 100644 R/kfold-generic.R create mode 100644 man/kfold-generic.Rd diff --git a/NAMESPACE b/NAMESPACE index fcbb1885..71b86d2d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -50,6 +50,7 @@ export(is.loo) export(is.psis) export(is.psis_loo) export(is.waic) +export(kfold) export(kfold_split_balanced) export(kfold_split_random) export(kfold_split_stratified) diff --git a/R/kfold-generic.R b/R/kfold-generic.R new file mode 100644 index 00000000..71ed9467 --- /dev/null +++ b/R/kfold-generic.R @@ -0,0 +1,28 @@ +#' Generic function for K-fold cross-validation for developers +#' +#' For developers of modeling packages, \pkg{loo} includes a generic function +#' \code{kfold} so that methods may be defined for K-fold CV without name +#' conflicts between packages. See, e.g., the \code{kfold.stanreg} method in +#' \pkg{rstanarm} and the \code{kfold.brmsfit} method in \pkg{brms}. +#' +#' @name kfold-generic +#' @param x A fitted model object. +#' @param ... Arguments to pass to specific methods. +#' +#' @return For developers defining a \code{kfold} method for a class +#' \code{"foo"}, the \code{kfold.foo} function should return a list with class +#' \code{c("kfold", "loo")} with at least the elements +#' \itemize{ +#' \item \code{"estimates"}: a 1x2 matrix with column names "Estimate" and "SE" +#' containing the ELPD estimate and its standard error. +#' \item \code{"pointwise"}: an Nx1 matrix with column name "elpd_kfold" containing +#' the pointwise contributions for each data point. +#' } +#' +NULL + +#' @rdname kfold-generic +#' @export +kfold <- function(x, ...) { + UseMethod("kfold") +} diff --git a/R/kfold-helpers.R b/R/kfold-helpers.R index ddf0065c..815902bd 100644 --- a/R/kfold-helpers.R +++ b/R/kfold-helpers.R @@ -1,17 +1,23 @@ #' Helper functions for K-fold cross-validation #' -#' These functions can be used to generate indexes for use with K-fold -#' cross-validation. +#' @description These functions can be used to generate indexes for use with +#' K-fold cross-validation. See the \strong{Details} section for explanations. +#' +#' For package developers, see also \link{kfold-generic} for tips on defining +#' \code{kfold} methods for your fitted model objects. #' #' @name kfold-helpers #' @param K The number of folds to use. #' @param N The number of observations in the data. -#' @param x A discrete variable of length \code{N}. Will be coerced to -#' \code{\link{factor}}. For \code{kfold_split_balanced} \code{x} should be a -#' binary variable. For \code{kfold_split_stratified} \code{x} should be a -#' grouping variable with at least \code{K} levels. -#' @return An integer vector of length \code{N} where each element is an index -#' in \code{1:K}. +#' @param x For the helper functions, \code{x} should be a discrete variable of +#' length \code{N} (will be coerced to \code{\link{factor}}). For +#' \code{kfold_split_balanced} \code{x} should be a binary variable. For +#' \code{kfold_split_stratified} \code{x} should be a grouping variable with +#' at least \code{K} levels. + +#' @return The helper functions return an integer vector of length \code{N} +#' where each element is an index in \code{1:K}. +#' #' #' @details #' \code{kfold_split_random} splits the data into \code{K} groups diff --git a/R/loo_compare.R b/R/loo_compare.R index a22b77a3..78e184bc 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -99,7 +99,7 @@ loo_compare.default <- function(x, ...) { # compute elpd_diff and se_elpd_diff relative to best model rnms <- rownames(comp) - diffs <- mapply(elpd_diffs, loos[ord[1]], loos[ord]) + diffs <- mapply(FUN = elpd_diffs, loos[ord[1]], loos[ord]) elpd_diff <- apply(diffs, 2, sum) se_diff <- apply(diffs, 2, se_elpd_diff) comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) diff --git a/man/kfold-generic.Rd b/man/kfold-generic.Rd new file mode 100644 index 00000000..a1a6bceb --- /dev/null +++ b/man/kfold-generic.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/kfold-generic.R +\name{kfold-generic} +\alias{kfold-generic} +\alias{kfold} +\title{Generic function for K-fold cross-validation for developers} +\usage{ +kfold(x, ...) +} +\arguments{ +\item{x}{A fitted model object.} + +\item{...}{Arguments to pass to specific methods.} +} +\value{ +For developers defining a \code{kfold} method for a class + \code{"foo"}, the \code{kfold.foo} function should return a list with class + \code{c("kfold", "loo")} with at least the elements + \itemize{ + \item \code{"estimates"}: a 1x2 matrix with column names "Estimate" and "SE" + containing the ELPD estimate and its standard error. + \item \code{"pointwise"}: an Nx1 matrix with column name "elpd_kfold" containing + the pointwise contributions for each data point. + } +} +\description{ +For developers of modeling packages, \pkg{loo} includes a generic function +\code{kfold} so that methods may be defined for K-fold CV without name +conflicts between packages. See, e.g., the \code{kfold.stanreg} method in +\pkg{rstanarm} and the \code{kfold.brmsfit} method in \pkg{brms}. +} diff --git a/man/kfold-helpers.Rd b/man/kfold-helpers.Rd index 9435ee1b..853663b9 100644 --- a/man/kfold-helpers.Rd +++ b/man/kfold-helpers.Rd @@ -18,18 +18,22 @@ kfold_split_stratified(K = 10, x = NULL) \item{N}{The number of observations in the data.} -\item{x}{A discrete variable of length \code{N}. Will be coerced to -\code{\link{factor}}. For \code{kfold_split_balanced} \code{x} should be a -binary variable. For \code{kfold_split_stratified} \code{x} should be a -grouping variable with at least \code{K} levels.} +\item{x}{For the helper functions, \code{x} should be a discrete variable of +length \code{N} (will be coerced to \code{\link{factor}}). For +\code{kfold_split_balanced} \code{x} should be a binary variable. For +\code{kfold_split_stratified} \code{x} should be a grouping variable with +at least \code{K} levels.} } \value{ -An integer vector of length \code{N} where each element is an index - in \code{1:K}. +The helper functions return an integer vector of length \code{N} +where each element is an index in \code{1:K}. } \description{ -These functions can be used to generate indexes for use with K-fold -cross-validation. +These functions can be used to generate indexes for use with + K-fold cross-validation. See the \strong{Details} section for explanations. + + For package developers, see also \link{kfold-generic} for tips on defining + \code{kfold} methods for your fitted model objects. } \details{ \code{kfold_split_random} splits the data into \code{K} groups From 87017677889bd3c9d078cebc42b07f3e486b9c51 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 15:52:09 +0200 Subject: [PATCH 09/17] remove unused rds files --- compare_three_models.rds | Bin 346 -> 0 bytes compare_two_models.rds | Bin 127 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 compare_three_models.rds delete mode 100644 compare_two_models.rds diff --git a/compare_three_models.rds b/compare_three_models.rds deleted file mode 100644 index a7abb8a5d358ffd5ded78c663f17645169fbb43d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 346 zcmV-g0j2&QiwFP!000001B>8dU|?WkU}j}xU}6R`nfZW(1OpTt5U>i4pX2)GfC2BR ziQ8D<;`ZCJ+_&fpi`buD|MsEDTbBbN4?o;J=@fAw^yxI=!)!YbL_AGyGCQ2)Ajo06 zS+F3%K~QR{5bsTU2Z2wlcQO_TIdEwlNW8i;$U!ZA>T8AzoeqvCZUjl)<+T6!H}y@? z=Xm>%vlHgzrXREa*l=M|`pI~Q6o~l_sSxuWGN9%wLd{o%ny&ycpMime3Fu5l1{MZR zkf%~IbAc2u*sDM;5Obg@;Q%VhOUz9z2C{{r>IA@QA@WS+hH#bl`|8;1G*Q zdP+j#1C~o`Dm!}>+~(vH%i3HmlxILIVJ_J diu0?p3+b@8S67vboYDADp5bL{(3_P&djU?{F+TtR From a4ee11cbbd3409bc6fda3b1192c93e737d673641 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 18 Sep 2018 16:11:18 +0200 Subject: [PATCH 10/17] print_dims.kfold [ci skip] --- NAMESPACE | 1 + R/print.R | 9 +++++++++ man/print_dims.Rd | 3 +++ 3 files changed, 13 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 71b86d2d..12c750dd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -25,6 +25,7 @@ S3method(print,psis) S3method(print,psis_loo) S3method(print,stacking_weights) S3method(print,waic) +S3method(print_dims,kfold) S3method(print_dims,psis) S3method(print_dims,psis_loo) S3method(print_dims,waic) diff --git a/R/print.R b/R/print.R index a4dfb524..cc546c61 100644 --- a/R/print.R +++ b/R/print.R @@ -104,6 +104,15 @@ print_dims.waic <- function(x, ...) { ) } +#' @rdname print_dims +#' @export +print_dims.kfold <- function(x, ...) { + K <- attr(x, "K", exact = TRUE) + if (!is.null(K)) { + cat("Based on", paste0(K, "-fold"), "cross-validation\n") + } +} + print_mcse_summary <- function(x, digits) { mcse_val <- mcse_loo(x) diff --git a/man/print_dims.Rd b/man/print_dims.Rd index 3f275498..2bab964e 100644 --- a/man/print_dims.Rd +++ b/man/print_dims.Rd @@ -5,6 +5,7 @@ \alias{print_dims.psis} \alias{print_dims.psis_loo} \alias{print_dims.waic} +\alias{print_dims.kfold} \title{Print dimensions of log-likelihood or log-weights matrix} \usage{ print_dims(x, ...) @@ -14,6 +15,8 @@ print_dims(x, ...) \method{print_dims}{psis_loo}(x, ...) \method{print_dims}{waic}(x, ...) + +\method{print_dims}{kfold}(x, ...) } \arguments{ \item{x}{The object returned by \code{\link{psis}}, \code{\link{loo}}, or From 506106cd41638011dba71269c31381a76b9efac8 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 20 Dec 2018 12:55:05 -0500 Subject: [PATCH 11/17] update tests for loo_compare --- tests/testthat/compare_three_models.rds | Bin 339 -> 346 bytes tests/testthat/compare_two_models.rds | Bin 120 -> 127 bytes tests/testthat/loo_compare_three_models.rds | Bin 0 -> 343 bytes tests/testthat/loo_compare_two_models.rds | Bin 0 -> 277 bytes tests/testthat/test_compare.R | 81 +++++++++++++++----- 5 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 tests/testthat/loo_compare_three_models.rds create mode 100644 tests/testthat/loo_compare_two_models.rds diff --git a/tests/testthat/compare_three_models.rds b/tests/testthat/compare_three_models.rds index 76d81dd4bd9bb70b0aed7612390a7d4825627022..a7abb8a5d358ffd5ded78c663f17645169fbb43d 100644 GIT binary patch delta 75 zcmcc2bc;z;zMF#q44AtgBqbyyBqgM!rfd|IWYn3(zhm;}Vs4@Cpsz)rJ|~(zRhuS~ d8s<>>=clFS#}i!=68{((JbzYJ$}=%A007vY8{_~0 delta 68 zcmcb`beTz1zMF#q44AtgBqbyyBqgLJCTl5Jk?#Z7!O-B}m?49$cW`^(1 VPF2^Lq7K*W8T^x^Lphij7yvuu80-K5 diff --git a/tests/testthat/compare_two_models.rds b/tests/testthat/compare_two_models.rds index 48ddbdbf1d669b27e09d0d8358332beb7ebe40a1..7261c8fb1f52be354462aa95f4c446b02c41a9a9 100644 GIT binary patch delta 63 zcmb=Z7nSekU;qQ=?gvQ;2?8dU|?WkU}j}xU}6R`nfZW(1OpTt5U>i4pX2)GfC2BR ziQ8D<;`ZCJ+_&fpi`buD|MsEDTbBbN4?o;J=@fAw^yxI=!)!YbL_AGyGCQ2)Ajo06 zS+F3%K~QR{5bsTU2Z2wlcQO_TIdEwlNW8i;$U!ZA>T8AzoeqvCZUjl)<+T6!H}y@? z=Xm>%vlHgzrXREa*l=M|`pI~Q6o~l_sSxuWGN9%wLd{o%ny&ycpMime3Fu5l1{MZR zkf%~IbAc2u*sDM;5Obg@;Q%VhOUz9z2C{{r>IA@QA@Xdw`6;P6hA2EE6rM3u2M1J? zGc~6mB|arHEe$HbUYv@|g~*pDW+p=gc!2WgVr&J-JWikr6af|_HpqA|7u92|$vKI| p#Zb2~K?S*!^K%Oli&FJ+^7G-INGvJJtN^n90RZPk0O=G0007{7mVf{N literal 0 HcmV?d00001 diff --git a/tests/testthat/loo_compare_two_models.rds b/tests/testthat/loo_compare_two_models.rds new file mode 100644 index 0000000000000000000000000000000000000000..546024d2e93ec087033b2cafc768831a2e1ea9bc GIT binary patch literal 277 zcmV+w0qXuAiwFP!000001B>8dU|?WkU}j}xU}6R`nfZW(00R^p5U>i4pX2%l#<$;= z<-SEKk3cyHP}aA_P!yt*^UK`nji zYlaJ*_8lhhW7&t+WO3BOx zQoLZNLNEuK5)Pn}yu{qpVjx=>szv}zL*&_V^HWlD3{iMSP&FJ-QO?wyf|U4_%(OJ9 z0DEyNG8ZCWo|u^o72pBNql>W>AoDnZDo_Mikk}yO!CX|YuqNju78j$rpF24}w;-`7 bRWBz$AMT07lA_ECAnP9hn0H#T#sL5TKYVnH literal 0 HcmV?d00001 diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 15373cc0..165631ce 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -2,7 +2,7 @@ library(loo) set.seed(123) SW <- suppressWarnings -context("compare") +context("compare models") LLarr <- example_loglik_array() LLarr2 <- array(rnorm(prod(dim(LLarr)), c(LLarr), 0.5), dim = dim(LLarr)) @@ -11,47 +11,91 @@ w1 <- SW(waic(LLarr)) w2 <- SW(waic(LLarr2)) - - -# Tests for deprecated compare() ------------------------------------------ - -test_that("compare throws appropriate errors", { +test_that("loo_compare throws appropriate errors", { w3 <- SW(waic(LLarr[,, -1])) w4 <- SW(waic(LLarr[,, -(1:2)])) - expect_error(loo::compare(w1, w2, x = list(w1, w2)), - regexp = "If 'x' is specified then '...' should not be specified") - expect_error(loo::compare(w1, list(1,2,3)), + expect_error(loo_compare(w1, w2, x = list(w1, w2)), + regexp = "If 'x' is a list then '...' should not be specified") + expect_error(loo_compare(w1, list(1,2,3)), regexp = "class 'loo'") - expect_error(loo::compare(w1), + expect_error(loo_compare(w1), regexp = "requires at least two models") - expect_error(loo::compare(x = list(w1)), + expect_error(loo_compare(x = list(w1)), regexp = "requires at least two models") - expect_error(loo::compare(w1, w3), + expect_error(loo_compare(w1, w3), regexp = "same number of data points") - expect_error(loo::compare(w1, w2, w3), + expect_error(loo_compare(w1, w2, w3), regexp = "same number of data points") +}) + + + +comp_colnames <- c( + "elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic", + "p_waic", "se_p_waic", "waic", "se_waic" +) + +test_that("loo_compare returns expected results (2 models)", { + comp1 <- loo_compare(w1, w1) + expect_s3_class(comp1, "compare.loo") + expect_equal(colnames(comp1), comp_colnames) + expect_equal(rownames(comp1), c("model1", "model2")) + expect_output(print(comp1), "elpd_diff") + expect_equivalent(comp1[1:2,1], c(0, 0)) + expect_equivalent(comp1[1:2,2], c(0, 0)) + + comp2 <- loo_compare(w1, w2) + expect_s3_class(comp2, "compare.loo") + expect_equal_to_reference(comp2, "loo_compare_two_models.rds") + expect_equal(colnames(comp2), comp_colnames) + + # specifying objects via ... and via arg x gives equal results + expect_equal(comp2, loo_compare(x = list(w1, w2))) +}) + + +test_that("loo_compare returns expected result (3 models)", { + w3 <- SW(waic(LLarr3)) + comp1 <- loo_compare(w1, w2, w3) + + expect_equal(colnames(comp1), comp_colnames) + expect_equal(rownames(comp1), c("model1", "model2", "model3")) + expect_equal(comp1[1,1], 0) + expect_s3_class(comp1, "compare.loo") + expect_s3_class(comp1, "matrix") + expect_equal_to_reference(comp1, "loo_compare_three_models.rds") + + # specifying objects via '...' gives equivalent results (equal + # except rownames) to using 'x' argument + expect_equivalent(comp1, loo_compare(x = list(w1, w2, w3))) +}) + +# Tests for deprecated compare() ------------------------------------------ + +test_that("compare throws deprecation warnings", { expect_warning(loo::compare(w1, w2), "Deprecated") expect_warning(loo::compare(w1, w1, w2), "Deprecated") }) test_that("compare returns expected result (2 models)", { - comp1 <- loo::compare(w1, w1) + comp1 <- expect_warning(loo::compare(w1, w1), "Deprecated") expect_output(print(comp1), "elpd_diff") expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0)) - comp2 <- loo::compare(w1, w2) + comp2 <- expect_warning(loo::compare(w1, w2), "Deprecated") expect_equal_to_reference(comp2, "compare_two_models.rds") expect_named(comp2, c("elpd_diff", "se")) expect_s3_class(comp2, "compare.loo") # specifying objects via ... and via arg x gives equal results - expect_equal(comp2, loo::compare(x = list(w1, w2))) + comp_via_list <- expect_warning(loo::compare(x = list(w1, w2)), "Deprecated") + expect_equal(comp2, comp_via_list) }) test_that("compare returns expected result (3 models)", { w3 <- SW(waic(LLarr3)) - comp1 <- loo::compare(w1, w2, w3) + comp1 <- expect_warning(loo::compare(w1, w2, w3), "Deprecated") expect_equal( colnames(comp1), @@ -67,5 +111,6 @@ test_that("compare returns expected result (3 models)", { # specifying objects via '...' gives equivalent results (equal # except rownames) to using 'x' argument - expect_equivalent(comp1, loo::compare(x = list(w1, w2, w3))) + comp_via_list <- expect_warning(loo::compare(x = list(w1, w2, w3)), "Deprecated") + expect_equivalent(comp1, comp_via_list) }) From 97c29fe6fb1a6839ccbb681209082f6059559f5c Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 22 Dec 2018 13:05:03 -0500 Subject: [PATCH 12/17] is.kfold() helper function [ci skip] --- NAMESPACE | 1 + R/kfold-generic.R | 6 ++++++ man/kfold-generic.Rd | 3 +++ 3 files changed, 10 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 12c750dd..838d9ca3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -47,6 +47,7 @@ export(example_loglik_matrix) export(extract_log_lik) export(find_model_names) export(gpdfit) +export(is.kfold) export(is.loo) export(is.psis) export(is.psis_loo) diff --git a/R/kfold-generic.R b/R/kfold-generic.R index 71ed9467..8e7e2f63 100644 --- a/R/kfold-generic.R +++ b/R/kfold-generic.R @@ -26,3 +26,9 @@ NULL kfold <- function(x, ...) { UseMethod("kfold") } + +#' @rdname kfold-generic +#' @export +is.kfold <- function(x) { + inherits(x, "kfold") && is.loo(x) +} diff --git a/man/kfold-generic.Rd b/man/kfold-generic.Rd index a1a6bceb..ab81e55d 100644 --- a/man/kfold-generic.Rd +++ b/man/kfold-generic.Rd @@ -3,9 +3,12 @@ \name{kfold-generic} \alias{kfold-generic} \alias{kfold} +\alias{is.kfold} \title{Generic function for K-fold cross-validation for developers} \usage{ kfold(x, ...) + +is.kfold(x) } \arguments{ \item{x}{A fitted model object.} From df391412f61e3c9f35174da2a26ee259b60bb5d8 Mon Sep 17 00:00:00 2001 From: Jonah Gabry Date: Thu, 3 Jan 2019 15:05:06 -0500 Subject: [PATCH 13/17] Update NEWS.md --- NEWS.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index e9591cfc..d0c7d608 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,14 +1,16 @@ # loo 2.0.0.9000 +* Introduced new function `loo_compare()` to eventually replace `compare()`, +which now throws a deprecation warning. (#93) + * New vignette on LOO for non-factorizable joint Gaussian models. (#75) -* When comparing more than two models with `compare()` there is now also -an `se_diff` column in the results. The printed output (the returned object) -from `compare()` has also been simplified. (#78) +* New `se_diff` column in model comparison results. (#78) -* Fix for `psis()` when `log_ratios` are very small. (#74) +* Improved behavior of `psis()` when `log_ratios` are very small. (#74) -* Allow `r_eff=NA` to suppress warning from `psis()` when specifying `r_eff` is not applicable (i.e., draws not from MCMC). (#72) +* Allow `r_eff=NA` to suppress warning from `psis()` when specifying `r_eff` +is not applicable (i.e., draws not from MCMC). (#72) # loo 2.0.0 From ef43c0c2b657d8ff95be306c0fcd9f8775e86592 Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 5 Jan 2019 14:40:47 -0500 Subject: [PATCH 14/17] fix r cmd check warning --- R/psis.R | 1 + man/kfold-helpers.Rd | 7 ++----- man/psis.Rd | 3 +++ 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/R/psis.R b/R/psis.R index 690af8ea..d9b439c2 100644 --- a/R/psis.R +++ b/R/psis.R @@ -170,6 +170,7 @@ dim.psis <- function(x) { attr(x, "dims") } +#' @rdname psis #' @export is.psis <- function(x) { inherits(x, "psis") && is.list(x) diff --git a/man/kfold-helpers.Rd b/man/kfold-helpers.Rd index 38961683..df214636 100644 --- a/man/kfold-helpers.Rd +++ b/man/kfold-helpers.Rd @@ -19,13 +19,10 @@ kfold_split_grouped(K = 10, x = NULL) \item{N}{The number of observations in the data.} \item{x}{A discrete variable of length \code{N} with at least \code{K} levels -(unique values). Will be coerced to \code{\link{factor}}. -.} - +(unique values). Will be coerced to \code{\link{factor}}.} } \value{ -The helper functions return an integer vector of length \code{N} -where each element is an index in \code{1:K}. +An integer vector of length \code{N} where each element is an index in \code{1:K}. } \description{ These functions can be used to generate indexes for use with diff --git a/man/psis.Rd b/man/psis.Rd index 2755a1c6..e1ecaccb 100644 --- a/man/psis.Rd +++ b/man/psis.Rd @@ -6,6 +6,7 @@ \alias{psis.matrix} \alias{psis.default} \alias{weights.psis} +\alias{is.psis} \title{Pareto smoothed importance sampling (PSIS)} \usage{ psis(log_ratios, ...) @@ -19,6 +20,8 @@ psis(log_ratios, ...) \method{psis}{default}(log_ratios, ..., r_eff = NULL) \method{weights}{psis}(object, ..., log = TRUE, normalize = TRUE) + +is.psis(x) } \arguments{ \item{log_ratios}{An array, matrix, or vector of importance ratios on the log From e3ddd59ff49816ef4c1898679a6c812f6cfaf474 Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 5 Jan 2019 14:51:09 -0500 Subject: [PATCH 15/17] document is.psis --- R/psis.R | 1 + man/psis.Rd | 2 ++ 2 files changed, 3 insertions(+) diff --git a/R/psis.R b/R/psis.R index d9b439c2..ef767b6f 100644 --- a/R/psis.R +++ b/R/psis.R @@ -172,6 +172,7 @@ dim.psis <- function(x) { #' @rdname psis #' @export +#' @param x For \code{is.psis}, an object to check. is.psis <- function(x) { inherits(x, "psis") && is.list(x) } diff --git a/man/psis.Rd b/man/psis.Rd index e1ecaccb..1a9d8cfc 100644 --- a/man/psis.Rd +++ b/man/psis.Rd @@ -59,6 +59,8 @@ the log scale? Defaults to \code{TRUE}.} \item{normalize}{For the \code{weights} method, should the weights be normalized? Defaults to \code{TRUE}.} + +\item{x}{For \code{is.psis}, an object to check.} } \value{ The \code{psis} methods return an object of class \code{"psis"}, From 190f79a4c3c3463be659dc5a7bad280046d3c1d8 Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 27 Feb 2019 12:21:57 -0500 Subject: [PATCH 16/17] compare.R: remove deprecation warning for now --- R/compare.R | 2 +- tests/testthat/test_compare.R | 23 ++++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/R/compare.R b/R/compare.R index d2f93b49..755462bc 100644 --- a/R/compare.R +++ b/R/compare.R @@ -60,7 +60,7 @@ #' } #' compare <- function(..., x = list()) { - .Deprecated("loo_compare") + # .Deprecated("loo_compare") dots <- list(...) if (length(dots)) { if (length(x)) { diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 165631ce..0e6efa2d 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -73,29 +73,33 @@ test_that("loo_compare returns expected result (3 models)", { # Tests for deprecated compare() ------------------------------------------ -test_that("compare throws deprecation warnings", { - expect_warning(loo::compare(w1, w2), "Deprecated") - expect_warning(loo::compare(w1, w1, w2), "Deprecated") -}) +# test_that("compare throws deprecation warnings", { +# expect_warning(loo::compare(w1, w2), "Deprecated") +# expect_warning(loo::compare(w1, w1, w2), "Deprecated") +# }) test_that("compare returns expected result (2 models)", { - comp1 <- expect_warning(loo::compare(w1, w1), "Deprecated") + comp1 <- loo::compare(w1, w1) + # comp1 <- expect_warning(loo::compare(w1, w1), "Deprecated") expect_output(print(comp1), "elpd_diff") expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0)) - comp2 <- expect_warning(loo::compare(w1, w2), "Deprecated") + comp2 <- loo::compare(w1, w2) + # comp2 <- expect_warning(loo::compare(w1, w2), "Deprecated") expect_equal_to_reference(comp2, "compare_two_models.rds") expect_named(comp2, c("elpd_diff", "se")) expect_s3_class(comp2, "compare.loo") # specifying objects via ... and via arg x gives equal results - comp_via_list <- expect_warning(loo::compare(x = list(w1, w2)), "Deprecated") + comp_via_list <- loo::compare(x = list(w1, w2)) + # comp_via_list <- expect_warning(loo::compare(x = list(w1, w2)), "Deprecated") expect_equal(comp2, comp_via_list) }) test_that("compare returns expected result (3 models)", { w3 <- SW(waic(LLarr3)) - comp1 <- expect_warning(loo::compare(w1, w2, w3), "Deprecated") + comp1 <- loo::compare(w1, w2, w3) + # comp1 <- expect_warning(loo::compare(w1, w2, w3), "Deprecated") expect_equal( colnames(comp1), @@ -111,6 +115,7 @@ test_that("compare returns expected result (3 models)", { # specifying objects via '...' gives equivalent results (equal # except rownames) to using 'x' argument - comp_via_list <- expect_warning(loo::compare(x = list(w1, w2, w3)), "Deprecated") + comp_via_list <- loo::compare(x = list(w1, w2, w3)) + # comp_via_list <- expect_warning(loo::compare(x = list(w1, w2, w3)), "Deprecated") expect_equivalent(comp1, comp_via_list) }) From 943ea2cd6d627a66be6bae1a49866e19d8f9f986 Mon Sep 17 00:00:00 2001 From: jgabry Date: Wed, 27 Feb 2019 12:24:11 -0500 Subject: [PATCH 17/17] update NEWS [ci skip] --- NEWS.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index c2f17dc9..8a223357 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,5 @@ # loo 2.0.0.9000 -* Introduced new function `loo_compare()` to eventually replace `compare()`, -which now throws a deprecation warning. (#93) - * New vignette on LOO for non-factorizable joint Gaussian models. (#75) * New `se_diff` column in model comparison results. (#78)