From 115fdd0adf3dbf6da1fb51580455efe996e36be7 Mon Sep 17 00:00:00 2001 From: Florence Bockting Date: Tue, 30 Jun 2026 12:38:19 +0300 Subject: [PATCH 1/6] feat: use mirai and mori in do_importance_sampling --- R/importance_sampling.R | 52 +++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/R/importance_sampling.R b/R/importance_sampling.R index f8dfd4d3..59562f7e 100644 --- a/R/importance_sampling.R +++ b/R/importance_sampling.R @@ -199,27 +199,18 @@ do_importance_sampling <- function(log_ratios, r_eff, cores, method) { } if (cores == 1) { - lw_list <- lapply(seq_len(N), function(i) - is_fun(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i])) + lw_list <- lapply(seq_len(N), do_is_i, is_fun, log_ratios, tail_len) } else { - if (!os_is_windows()) { - lw_list <- parallel::mclapply( - X = seq_len(N), - mc.cores = cores, - FUN = function(i) - is_fun(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i]) - ) - } else { - cl <- parallel::makePSOCKcluster(cores) - on.exit(parallel::stopCluster(cl)) - lw_list <- - parallel::parLapply( - cl = cl, - X = seq_len(N), - fun = function(i) - is_fun(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i]) - ) - } + shared_lr <- mori::share(log_ratios) + lw_list <- with( + mirai::daemons(cores), + mirai::mirai_map( + seq_len(N), + do_is_i, + .args = list(is_fun = is_fun, log_ratios = shared_lr, + tail_len = tail_len) + )[] + ) } log_weights <- psis_apply(lw_list, "log_weights", fun_val = numeric(S)) @@ -234,3 +225,24 @@ do_importance_sampling <- function(log_ratios, r_eff, cores, method) { method = rep(method, length(pareto_k)) # Conform to other attr that exist per obs. ) } + +#' Apply an importance sampling method to a single observation +#' +#' @noRd +#' @keywords internal +#' @description +#' Worker function mapped over observations (matrix columns) by +#' [do_importance_sampling()], either serially via [lapply()] or in parallel +#' via [mirai::mirai_map()]. +#' @param i Integer column index of the observation to process. +#' @param is_fun The per-observation importance sampling function to apply, one +#' of [do_psis_i()], [do_tis_i()], or [do_sis_i()]. +#' @param log_ratios Matrix of log ratios (`-loglik`). May be a shared-memory +#' object created by [mori::share()] to avoid copying to each worker. +#' @param tail_len Vector of tail lengths used to fit the GPD, one per +#' observation. +#' @return The result of `is_fun` for observation `i` (a list with elements +#' such as `log_weights` and `pareto_k`). +do_is_i <- function(i, is_fun, log_ratios, tail_len) { + is_fun(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i]) +} \ No newline at end of file From 4b6bbc3ba1b7c4567f3d3ed4f4aff69b846746af Mon Sep 17 00:00:00 2001 From: Florence Bockting Date: Tue, 30 Jun 2026 14:33:37 +0300 Subject: [PATCH 2/6] refactor: use mirai/mori for parallelization --- .gitignore | 4 + DESCRIPTION | 2 + R/effective_sample_sizes.R | 101 ++++++-------- R/importance_sampling.R | 29 ++-- R/loo.R | 67 +++------ R/loo_model_weights.R | 27 ++-- R/loo_moment_matching.R | 87 ++++++++---- R/loo_subsample.R | 23 +-- R/parallel.R | 181 ++++++++++++++++++++++++ tests/testthat/test_parallel.R | 246 +++++++++++++++++++++++++++++++++ 10 files changed, 603 insertions(+), 164 deletions(-) create mode 100644 R/parallel.R create mode 100644 tests/testthat/test_parallel.R diff --git a/.gitignore b/.gitignore index 9c69942e..f070106a 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,7 @@ tests/testthat/Rplots.pdf cran-comments.md CRAN-RELEASE release-prep.R + +agent/* +data/* +scratch-files/* \ No newline at end of file diff --git a/DESCRIPTION b/DESCRIPTION index 29f41847..f673389a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -37,6 +37,8 @@ Depends: Imports: checkmate, matrixStats (>= 0.52), + mirai, + mori, parallel, posterior (>= 1.7.0), stats diff --git a/R/effective_sample_sizes.R b/R/effective_sample_sizes.R index 360a3098..ab565555 100644 --- a/R/effective_sample_sizes.R +++ b/R/effective_sample_sizes.R @@ -62,32 +62,29 @@ relative_eff.array <- function(x, ..., cores = getOption("mc.cores", 1)) { stopifnot(length(dim(x)) == 3) S <- prod(dim(x)[1:2]) # posterior sample size = iter * chains - if (cores == 1) { - n_eff_vec <- apply(x, 3, posterior::ess_mean) - } else { - if (!os_is_windows()) { - n_eff_list <- - parallel::mclapply( - mc.cores = cores, - X = seq_len(dim(x)[3]), - FUN = function(i) posterior::ess_mean(x[, , i, drop = TRUE]) - ) - } else { - cl <- parallel::makePSOCKcluster(cores) - on.exit(parallel::stopCluster(cl)) - n_eff_list <- - parallel::parLapply( - cl = cl, - X = seq_len(dim(x)[3]), - fun = function(i) posterior::ess_mean(x[, , i, drop = TRUE]) - ) - } - n_eff_vec <- unlist(n_eff_list, use.names = FALSE) - } + # The full draws array is reused across observations, so it is broadcast via + # shared memory on a local pool. Each worker reads only its slice `x[, , i]`. + n_eff_list <- with_loo_daemons( + cores, + loo_map( + seq_len(dim(x)[3]), + relative_eff_i_array, + cores = cores, + broadcast = list(x = x) + ) + ) + n_eff_vec <- unlist(n_eff_list, use.names = FALSE) return(n_eff_vec / S) } +#' Worker computing `ess_mean()` for a single slice of a draws array +#' @noRd +#' @keywords internal +relative_eff_i_array <- function(i, x) { + posterior::ess_mean(x[, , i, drop = TRUE]) +} + #' @export #' @templateVar fn relative_eff #' @template function @@ -104,46 +101,36 @@ relative_eff.function <- f_i <- validate_llfun(x) # not really an llfun, should return exp(ll) or exp(-ll) N <- dim(data)[1] - if (cores == 1) { - n_eff_list <- - lapply( - X = seq_len(N), - FUN = function(i) { - val_i <- f_i(data_i = data[i, , drop = FALSE], draws = draws, ...) - relative_eff.default(as.vector(val_i), chain_id = chain_id, cores = 1) - } - ) - } else { - if (!os_is_windows()) { - n_eff_list <- - parallel::mclapply( - X = seq_len(N), - FUN = function(i) { - val_i <- f_i(data_i = data[i, , drop = FALSE], draws = draws, ...) - relative_eff.default(as.vector(val_i), chain_id = chain_id, cores = 1) - }, - mc.cores = cores - ) - } else { - cl <- parallel::makePSOCKcluster(cores) - parallel::clusterExport(cl=cl, varlist=c("draws", "chain_id", "data"), envir=environment()) - on.exit(parallel::stopCluster(cl)) - n_eff_list <- - parallel::parLapply( - cl = cl, - X = seq_len(N), - fun = function(i) { - val_i <- f_i(data_i = data[i, , drop = FALSE], draws = draws, ...) - relative_eff.default(as.vector(val_i), chain_id = chain_id, cores = 1) - } - ) - } - } + # `data` and `draws` are reused for every observation, so they are broadcast + # via shared memory on a local pool and serialized on a remote pool. + n_eff_list <- with_loo_daemons( + cores, + loo_map( + seq_len(N), + relative_eff_i_function, + f_i = f_i, + chain_id = chain_id, + re_dots = list(...), + cores = cores, + broadcast = list(data = data, draws = draws) + ) + ) n_eff_vec <- unlist(n_eff_list, use.names = FALSE) return(n_eff_vec) } +#' Worker computing the relative effective sample size for observation `i` +#' @noRd +#' @keywords internal +relative_eff_i_function <- function(i, f_i, data, draws, chain_id, re_dots) { + val_i <- do.call( + f_i, + c(list(data_i = data[i, , drop = FALSE], draws = draws), re_dots) + ) + relative_eff.default(as.vector(val_i), chain_id = chain_id, cores = 1) +} + #' @export #' @describeIn relative_eff #' If `x` is an object of class `"psis"`, `relative_eff()` simply returns diff --git a/R/importance_sampling.R b/R/importance_sampling.R index 59562f7e..59d776c3 100644 --- a/R/importance_sampling.R +++ b/R/importance_sampling.R @@ -198,20 +198,23 @@ do_importance_sampling <- function(log_ratios, r_eff, cores, method) { stop("Incorrect IS method.") } - if (cores == 1) { - lw_list <- lapply(seq_len(N), do_is_i, is_fun, log_ratios, tail_len) - } else { - shared_lr <- mori::share(log_ratios) - lw_list <- with( - mirai::daemons(cores), - mirai::mirai_map( - seq_len(N), - do_is_i, - .args = list(is_fun = is_fun, log_ratios = shared_lr, - tail_len = tail_len) - )[] + # Each observation needs a different column of `log_ratios`, but the whole + # matrix is reused across the map, so it is a broadcast object: `loo_map()` + # shares it via shared memory on a local pool (zero-copy column access) and + # falls back to serialization on a remote pool. Serial work runs as a plain + # lapply(). `with_loo_daemons()` provides a pool when this is the top-level + # call (e.g. psis()) and reuses an outer pool when called from loo(). + lw_list <- with_loo_daemons( + cores, + loo_map( + seq_len(N), + do_is_i, + is_fun = is_fun, + tail_len = tail_len, + cores = cores, + broadcast = list(log_ratios = log_ratios) ) - } + ) log_weights <- psis_apply(lw_list, "log_weights", fun_val = numeric(S)) pareto_k <- psis_apply(lw_list, "pareto_k") diff --git a/R/loo.R b/R/loo.R index 10b1bdc7..6830951e 100644 --- a/R/loo.R +++ b/R/loo.R @@ -665,52 +665,23 @@ parallel_importance_sampling_list <- function(N, .loo_i, .llfun, data, draws, r_eff, save_psis, cores, method, ...){ - if (cores == 1) { - psis_list <- - lapply( - X = seq_len(N), - FUN = .loo_i, - llfun = .llfun, - data = data, - draws = draws, - r_eff = r_eff, - save_psis = save_psis, - is_method = method, - ... - ) - } else { - if (!os_is_windows()) { - # On Mac or Linux use mclapply() for multiple cores - psis_list <- - parallel::mclapply( - mc.cores = cores, - X = seq_len(N), - FUN = .loo_i, - llfun = .llfun, - data = data, - draws = draws, - r_eff = r_eff, - save_psis = save_psis, - is_method = method, - ... - ) - } else { - # On Windows use makePSOCKcluster() and parLapply() for multiple cores - cl <- parallel::makePSOCKcluster(cores) - on.exit(parallel::stopCluster(cl)) - psis_list <- - parallel::parLapply( - cl = cl, - X = seq_len(N), - fun = .loo_i, - llfun = .llfun, - data = data, - draws = draws, - r_eff = r_eff, - save_psis = save_psis, - is_method = method, - ... - ) - } - } + # `draws` (and `data`) are reused identically for every observation, so they + # are broadcast objects: shared once via shared memory on a local pool + # (recovering the copy-on-write benefit fork gave the old mclapply() path) + # and serialized on a remote pool. A single cross-platform code path replaces + # the previous mclapply()/parLapply() branching. + with_loo_daemons( + cores, + loo_map( + seq_len(N), + .loo_i, + llfun = .llfun, + r_eff = r_eff, + save_psis = save_psis, + is_method = method, + ..., + cores = cores, + broadcast = list(data = data, draws = draws) + ) + ) } diff --git a/R/loo_model_weights.R b/R/loo_model_weights.R index 946dc7c3..3eb8d63c 100644 --- a/R/loo_model_weights.R +++ b/R/loo_model_weights.R @@ -188,15 +188,24 @@ loo_model_weights.default <- N <- ncol(x[[1]]) # number of data points validate_log_lik_list(x) validate_r_eff_list(r_eff_list, K, N) - lpd_point <- matrix(NA, N, K) - elpd_loo <- rep(NA, K) - for (k in 1:K) { - r_eff_k <- r_eff_list[[k]] # possibly NULL - log_likelihood <- x[[k]] - loo_object <- loo(log_likelihood, r_eff = r_eff_k, cores = cores) - lpd_point[, k] <- loo_object$pointwise[, "elpd_loo"] #calculate log(p_k (y_i | y_-i)) - elpd_loo[k] <- loo_object$estimates["elpd_loo", "Estimate"] - } + # Establish a single daemon pool for all K models so each inner loo() + # reuses it instead of spinning a pool up and down K times. + loo_objects <- with_loo_daemons( + cores, + lapply(seq_len(K), function(k) { + loo(x[[k]], r_eff = r_eff_list[[k]], cores = cores) + }) + ) + lpd_point <- vapply( + loo_objects, + function(o) o$pointwise[, "elpd_loo"], #calculate log(p_k (y_i | y_-i)) + FUN.VALUE = numeric(N) + ) + elpd_loo <- vapply( + loo_objects, + function(o) o$estimates["elpd_loo", "Estimate"], + FUN.VALUE = numeric(1) + ) } else if (is.psis_loo(x[[1]])) { validate_psis_loo_list(x) lpd_point <- do.call(cbind, lapply(x, function(obj) obj$pointwise[, "elpd_loo"])) diff --git a/R/loo_moment_matching.R b/R/loo_moment_matching.R index 110eff93..c37f9102 100644 --- a/R/loo_moment_matching.R +++ b/R/loo_moment_matching.R @@ -111,32 +111,27 @@ loo_moment_match.default <- function(x, loo, post_draws, log_lik_i, kfs <- rep(0,N) I <- which(ks > k_threshold) - loo_moment_match_i_fun <- function(i) { - loo_moment_match_i(i = i, x = x, log_lik_i = log_lik_i, - unconstrain_pars = unconstrain_pars, - log_prob_upars = log_prob_upars, - log_lik_i_upars = log_lik_i_upars, - max_iters = max_iters, k_threshold = k_threshold, - split = split, cov = cov, N = N, S = S, upars = upars, - orig_log_prob = orig_log_prob, k = ks[i], - is_method = is_method, npars = npars, ...) - } - - if (cores == 1) { - mm_list <- lapply(X = I, FUN = function(i) loo_moment_match_i_fun(i)) - } - else { - if (!os_is_windows()) { - mm_list <- parallel::mclapply(X = I, mc.cores = cores, - FUN = function(i) loo_moment_match_i_fun(i)) - } - else { - cl <- parallel::makePSOCKcluster(cores) - on.exit(parallel::stopCluster(cl)) - mm_list <- parallel::parLapply(cl = cl, X = I, - fun = function(i) loo_moment_match_i_fun(i)) - } - } + # The large unconstrained-draws matrix `upars` and the `orig_log_prob` vector + # are reused for every high-Pareto-k observation, so they are broadcast via + # shared memory on a local pool. The worker is the namespace-level + # `loo_moment_match_i_worker()` (rather than a closure over this frame) so the + # broadcast objects are not also dragged along inside a captured environment. + mm_list <- with_loo_daemons( + cores, + loo_map( + I, + loo_moment_match_i_worker, + x = x, ks = ks, log_lik_i = log_lik_i, + unconstrain_pars = unconstrain_pars, + log_prob_upars = log_prob_upars, + log_lik_i_upars = log_lik_i_upars, + max_iters = max_iters, k_threshold = k_threshold, + split = split, cov = cov, N = N, S = S, + is_method = is_method, npars = npars, mm_dots = list(...), + cores = cores, + broadcast = list(upars = upars, orig_log_prob = orig_log_prob) + ) + ) # update results for (ii in seq_along(I)) { @@ -230,6 +225,46 @@ loo_moment_match.default <- function(x, loo, post_draws, log_lik_i, #' @param ... Further arguments passed to the custom functions documented above. #' @return List with the updated elpd values and diagnostics #' +#' Worker wrapper around [loo_moment_match_i()] for parallel mapping +#' +#' @noRd +#' @keywords internal +#' @description +#' A namespace-level (non-closure) adapter mapped over high-Pareto-k +#' observation indices by [loo_map()]. Keeping it at namespace scope means it +#' does not capture the calling frame, so large objects shared via +#' [mori::share()] (`upars`, `orig_log_prob`) are not duplicated inside a +#' serialized closure environment. The per-observation Pareto k is selected +#' here from the full `ks` vector, and any extra arguments are forwarded +#' through `mm_dots`. +#' @param i Integer observation index. +#' @param ks Full vector of Pareto k estimates; `ks[i]` is used for this fold. +#' @param mm_dots A list of additional arguments forwarded to +#' [loo_moment_match_i()] (the `...` from [loo_moment_match()]). +#' @return The result of [loo_moment_match_i()] for observation `i`. +loo_moment_match_i_worker <- function(i, x, ks, log_lik_i, unconstrain_pars, + log_prob_upars, log_lik_i_upars, + max_iters, k_threshold, split, cov, + N, S, upars, orig_log_prob, is_method, + npars, mm_dots) { + do.call( + loo_moment_match_i, + c( + list( + i = i, x = x, log_lik_i = log_lik_i, + unconstrain_pars = unconstrain_pars, + log_prob_upars = log_prob_upars, + log_lik_i_upars = log_lik_i_upars, + max_iters = max_iters, k_threshold = k_threshold, + split = split, cov = cov, N = N, S = S, upars = upars, + orig_log_prob = orig_log_prob, k = ks[i], + is_method = is_method, npars = npars + ), + mm_dots + ) + ) +} + loo_moment_match_i <- function(i, x, log_lik_i, diff --git a/R/loo_subsample.R b/R/loo_subsample.R index bcac4b17..5912d993 100644 --- a/R/loo_subsample.R +++ b/R/loo_subsample.R @@ -494,17 +494,18 @@ lpd_i <- function(i, llfun, data, draws) { #' @noRd #' @return a vector of computed log probability densities compute_lpds <- function(N, data, draws, llfun, cores) { - if (cores == 1) { - lpds <- lapply(X = seq_len(N), FUN = lpd_i, llfun, data, draws) - } else { - if (.Platform$OS.type != "windows") { - lpds <- mclapply(X = seq_len(N), mc.cores = cores, FUN = lpd_i, llfun, data, draws) - } else { - cl <- makePSOCKcluster(cores) - on.exit(stopCluster(cl)) - lpds <- parLapply(cl, X = seq_len(N), fun = lpd_i, llfun, data, draws) - } - } + # `draws` (and `data`) are reused for every observation, so they are shared + # once via shared memory on a local pool and serialized on a remote pool. + lpds <- with_loo_daemons( + cores, + loo_map( + seq_len(N), + lpd_i, + llfun = llfun, + cores = cores, + broadcast = list(data = data, draws = draws) + ) + ) unlist(lpds) } diff --git a/R/parallel.R b/R/parallel.R new file mode 100644 index 00000000..4c8fbe3c --- /dev/null +++ b/R/parallel.R @@ -0,0 +1,181 @@ +#' Evaluate parallel work with an appropriate mirai daemon pool +#' +#' @noRd +#' @keywords internal +#' @description +#' Central entry point used by loo's parallel code paths to ensure a +#' [mirai::daemons()] pool exists for the duration of a computation. It is +#' deliberately a good citizen of the user's session: +#' +#' * `cores <= 1`: runs `code` serially without touching daemons. +#' * A daemon pool is already configured (e.g. the user called +#' [mirai::daemons()] themselves, possibly with remote/HPC daemons): `code` +#' runs on the existing pool, which is left untouched. +#' * Otherwise: a pool of `cores` local daemons is created for the duration of +#' `code` and automatically reset afterwards (via the scoped +#' `with(mirai::daemons(), ...)` method), so no daemon processes are left +#' running once the call returns. +#' +#' This keeps a single pool alive across the whole top-level computation +#' (rather than spinning daemons up and down for each unit of work) while +#' respecting any pool the user has already declared. Because it reuses an +#' existing pool, it is safe to nest: an inner call made while an outer call +#' already established a pool simply reuses it instead of creating another. +#' +#' @param cores Integer number of cores requested by the user. +#' @param code Expression to evaluate. Lazily evaluated in the calling +#' environment, after any daemon pool has been set up. +#' @return The value of `code`. +with_loo_daemons <- function(cores, code) { + if (cores <= 1 || loo_has_pool()) { + # Serial work, or reuse the daemon pool the user (or an outer loo call) + # already configured. + return(code) + } + # No pool configured: create one scoped to this computation and reset it on + # exit. `code` (including result collection via `[]`) is forced before the + # daemons are torn down. + with(mirai::daemons(cores), code) +} + +#' Is a mirai daemon pool currently connected? +#' +#' @noRd +#' @keywords internal +#' @return `TRUE` if at least one daemon connection exists for the active +#' compute profile, otherwise `FALSE`. +loo_has_pool <- function() { + conns <- tryCatch(mirai::status()$connections, error = function(e) 0L) + isTRUE(as.integer(conns) > 0L) +} + +#' Number of workers available for chunking decisions +#' +#' @noRd +#' @keywords internal +#' @description +#' Returns the number of connected daemons when a pool exists (so chunking +#' matches the actual worker count, including user-supplied or remote pools), +#' otherwise falls back to the requested `cores`. +#' @param cores Integer number of cores requested by the user. +loo_n_workers <- function(cores) { + conns <- tryCatch(mirai::status()$connections, error = function(e) 0L) + conns <- as.integer(conns) + if (length(conns) != 1L || is.na(conns) || conns < 1L) { + return(as.integer(cores)) + } + conns +} + +#' Is the active daemon pool on the local machine? +#' +#' @noRd +#' @keywords internal +#' @description +#' Determines whether shared memory ([mori::share()]) can be used safely with +#' the active pool. Shared memory only works when workers run on the same +#' physical machine, so we only treat same-host transports as local: +#' +#' * `abstract://` and `ipc://` are same-machine inter-process transports used +#' by local [mirai::daemons()] pools, so these are treated as local. +#' * `tcp://` (and anything else) may be a remote pool, or the host URL that +#' remote SSH/HPC daemons dial back to, so it is treated as **not** local. +#' loo then falls back to ordinary serialization instead of shared memory. +#' +#' This is intentionally conservative: an incorrect "local" classification +#' would produce wrong results on a remote pool, whereas an incorrect "remote" +#' classification merely forgoes the zero-copy optimisation. +#' @return `TRUE` if the pool is confirmed local, otherwise `FALSE`. +loo_pool_is_local <- function() { + urls <- tryCatch(mirai::status()$daemons, error = function(e) NULL) + if (!is.character(urls) || length(urls) == 0L) { + return(FALSE) + } + all(grepl("^(abstract|ipc)://", urls)) +} + +#' Map a worker over elements, serially or across a mirai daemon pool +#' +#' @noRd +#' @keywords internal +#' @description +#' Single cross-platform entry point for loo's per-observation parallelism. +#' Replaces the previous platform-branching +#' [parallel::mclapply()] / [parallel::parLapply()] code paths with a single +#' [mirai::mirai_map()] path, while preserving the serial [lapply()] behaviour +#' when no parallelism is requested or available. +#' +#' Object transport is chosen automatically: +#' +#' * `broadcast` objects are reused identically by every element (e.g. the +#' posterior `draws` matrix). On a local pool they are written once into +#' shared memory with [mori::share()] so each daemon maps the same physical +#' pages (zero-copy). On a remote pool, where shared memory is unavailable, +#' they are serialized instead; chunking bounds the number of copies sent to +#' roughly one per worker. +#' * Small per-call arguments are passed through `...`. +#' +#' @param X A vector or list to iterate over. Each element is passed as the +#' first argument to `FUN`. +#' @param FUN Worker function. Called as `FUN(x, , <...>)`; the +#' names in `broadcast` and `...` must match `FUN`'s formals. +#' @param ... Small constant arguments forwarded to `FUN` for every element. +#' @param cores Integer number of cores requested by the user. Parallelism is +#' only used when `cores > 1` and a daemon pool is connected. +#' @param broadcast Named list of large objects reused by every element. See +#' Description for how these are transported. +#' @param chunk Chunking strategy. `"auto"` (default) splits `X` into roughly +#' one chunk per worker to amortise per-task overhead -- best for cheap +#' per-element work over many elements. `"never"` dispatches one task per +#' element for finer load balancing -- best for expensive, uneven per-element +#' work. `"never"` is automatically promoted to `"auto"` on a remote pool +#' that carries `broadcast` objects, to avoid re-sending them per task. +#' @return A list of `FUN` results in the same order as `X`. +loo_map <- function(X, FUN, ..., cores = 1L, broadcast = list(), + chunk = c("auto", "never")) { + chunk <- match.arg(chunk) + dots <- list(...) + + if (!(cores > 1L && loo_has_pool())) { + # Serial path: identical behaviour to a plain lapply() with the broadcast + # and constant arguments supplied by name. + return(do.call(lapply, c(list(X, FUN), broadcast, dots))) + } + + local_pool <- loo_pool_is_local() + if (length(broadcast) > 0L) { + if (local_pool) { + # Zero-copy: write once to shared memory, ship tiny references. + broadcast <- lapply(broadcast, mori::share) + } else if (chunk == "never") { + # Remote pool: avoid re-serializing large broadcast objects once per + # task by collapsing to one chunk per worker instead. + chunk <- "auto" + } + } + const_args <- c(broadcast, dots) + + if (chunk == "never") { + return( + mirai::mirai_map( + X, + function(.x, .FUN, .const) do.call(.FUN, c(list(.x), .const)), + .args = list(.FUN = FUN, .const = const_args) + )[mirai::.stop] + ) + } + + n_chunks <- min(loo_n_workers(cores), length(X)) + positions <- parallel::splitIndices(length(X), n_chunks) + chunks <- lapply(positions, function(p) X[p]) + chunk_results <- mirai::mirai_map( + chunks, + function(.chunk, .FUN, .const) { + lapply(.chunk, function(.x) do.call(.FUN, c(list(.x), .const))) + }, + .args = list(.FUN = FUN, .const = const_args) + )[mirai::.stop] + # splitIndices() returns contiguous ascending groups, so concatenating the + # per-chunk lists restores the original order of X. + do.call(c, chunk_results) +} diff --git a/tests/testthat/test_parallel.R b/tests/testthat/test_parallel.R new file mode 100644 index 00000000..dc887c04 --- /dev/null +++ b/tests/testthat/test_parallel.R @@ -0,0 +1,246 @@ +options(mc.cores = 1) +set.seed(123) + +# Make sure no daemon pool leaks in from another test file. +mirai::daemons(0) + +LLarr <- example_loglik_array() +LLmat <- example_loglik_matrix() +chain_id <- rep(1:2, each = dim(LLarr)[1]) +r_eff <- relative_eff(exp(LLarr)) + +# Shared data for the function-method end-to-end checks. +set.seed(1) +S_fn <- 200 +N_fn <- 30 +draws_fn <- cbind(mu = rnorm(S_fn), sigma = abs(rnorm(S_fn)) + 0.5) +data_fn <- data.frame(y = rnorm(N_fn)) +llfun_test <- function(data_i, draws, ...) { + dnorm(data_i$y, mean = draws[, "mu"], sd = draws[, "sigma"], log = TRUE) +} + + +# Pool-introspection helpers ------------------------------------------------- + +test_that("loo_has_pool() and loo_pool_is_local() detect a local pool", { + mirai::daemons(0) + expect_false(loo:::loo_has_pool()) + expect_false(loo:::loo_pool_is_local()) + + skip_on_cran() + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + expect_true(loo:::loo_has_pool()) + expect_true(loo:::loo_pool_is_local()) + # Chunking uses the connected daemon count, not the requested cores. + expect_equal(loo:::loo_n_workers(1), 2L) +}) + +test_that("loo_pool_is_local() is FALSE for a tcp pool (remote-safety gate)", { + skip_on_cran() + mirai::daemons(0) + mirai::daemons(n = 2, url = mirai::local_url(tcp = TRUE)) + on.exit(mirai::daemons(0), add = TRUE) + # The locality gate reads the configured transport URL (available + # immediately, regardless of connection timing). tcp:// may be a remote/SSH + # pool, so shared memory must not be assumed. + expect_false(loo:::loo_pool_is_local()) +}) + + +# loo_map() ------------------------------------------------------------------ + +test_that("loo_map() runs serially when no pool is available", { + mirai::daemons(0) + res <- loo:::loo_map(1:5, function(x, m) x * m, m = 2, cores = 4) + expect_identical(res, as.list((1:5) * 2)) +}) + +test_that("loo_map() runs serially when cores <= 1 even with a pool", { + skip_on_cran() + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + res <- loo:::loo_map(1:5, function(x, m) x * m, m = 3, cores = 1) + expect_identical(res, as.list((1:5) * 3)) +}) + +test_that("loo_map() parallel matches serial and preserves order", { + skip_on_cran() + worker <- function(i, mat, add) sum(mat[, i]) + add + mat <- matrix(as.numeric(1:60), nrow = 6) # 6 x 10 + N <- ncol(mat) + expected <- lapply(seq_len(N), worker, mat = mat, add = 100) + + mirai::daemons(3) + on.exit(mirai::daemons(0), add = TRUE) + + # broadcast object shared via mori on a local pool; both chunk strategies + res_auto <- loo:::loo_map( + seq_len(N), worker, add = 100, cores = 3, + broadcast = list(mat = mat), chunk = "auto" + ) + res_never <- loo:::loo_map( + seq_len(N), worker, add = 100, cores = 3, + broadcast = list(mat = mat), chunk = "never" + ) + expect_identical(res_auto, expected) + expect_identical(res_never, expected) +}) + +test_that("loo_map() works when there are more workers than elements", { + skip_on_cran() + mirai::daemons(4) + on.exit(mirai::daemons(0), add = TRUE) + res <- loo:::loo_map(1:2, function(x) x + 1L, cores = 4) + expect_identical(res, list(2L, 3L)) +}) + +test_that("loo_map() propagates worker errors", { + skip_on_cran() + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + expect_error( + loo:::loo_map(1:4, function(x) if (x == 3L) stop("boom") else x, cores = 2), + "boom" + ) +}) + + +# End-to-end: importance sampling -------------------------------------------- + +test_that("psis() parallel equals serial", { + skip_on_cran() + ps_serial <- suppressWarnings(psis(-LLmat, r_eff = r_eff, cores = 1)) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + ps_parallel <- suppressWarnings(psis(-LLmat, r_eff = r_eff, cores = 2)) + + expect_equal(ps_serial$log_weights, ps_parallel$log_weights) + expect_equal(ps_serial$diagnostics, ps_parallel$diagnostics) +}) + +test_that("tis() and sis() parallel equal serial", { + skip_on_cran() + tis_serial <- suppressWarnings(tis(-LLmat, r_eff = r_eff, cores = 1)) + sis_serial <- suppressWarnings(sis(-LLmat, r_eff = r_eff, cores = 1)) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + tis_parallel <- suppressWarnings(tis(-LLmat, r_eff = r_eff, cores = 2)) + sis_parallel <- suppressWarnings(sis(-LLmat, r_eff = r_eff, cores = 2)) + + expect_equal(tis_serial$log_weights, tis_parallel$log_weights) + expect_equal(sis_serial$log_weights, sis_parallel$log_weights) +}) + + +# End-to-end: loo() function method (broadcast draws/data) ------------------- + +test_that("loo.function parallel equals serial", { + skip_on_cran() + loo_serial <- suppressWarnings( + loo(llfun_test, data = data_fn, draws = draws_fn, cores = 1) + ) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + loo_parallel <- suppressWarnings( + loo(llfun_test, data = data_fn, draws = draws_fn, cores = 2) + ) + + expect_equal(loo_serial$pointwise, loo_parallel$pointwise) + expect_equal(loo_serial$estimates, loo_parallel$estimates) +}) + +test_that("loo.function reuses an existing (user-configured) pool", { + skip_on_cran() + loo_serial <- suppressWarnings( + loo(llfun_test, data = data_fn, draws = draws_fn, cores = 1) + ) + + # User sets up the pool themselves; loo should reuse it untouched. + mirai::daemons(3) + on.exit(mirai::daemons(0), add = TRUE) + loo_reuse <- suppressWarnings( + loo(llfun_test, data = data_fn, draws = draws_fn, cores = 2) + ) + # Pool is still alive after the call (loo did not tear it down). + expect_true(loo:::loo_has_pool()) + expect_equal(loo_serial$pointwise, loo_reuse$pointwise) +}) + + +# End-to-end: relative_eff --------------------------------------------------- + +test_that("relative_eff() array and function methods are parallel-invariant", { + skip_on_cran() + re_arr_serial <- relative_eff(exp(LLarr), cores = 1) + re_fn_serial <- relative_eff( + llfun_test, chain_id = rep(1, S_fn), + data = data_fn, draws = draws_fn, cores = 1 + ) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + re_arr_parallel <- relative_eff(exp(LLarr), cores = 2) + re_fn_parallel <- relative_eff( + llfun_test, chain_id = rep(1, S_fn), + data = data_fn, draws = draws_fn, cores = 2 + ) + + expect_equal(re_arr_serial, re_arr_parallel) + expect_equal(re_fn_serial, re_fn_parallel) +}) + + +# End-to-end: loo_subsample -------------------------------------------------- + +test_that("loo_subsample() parallel equals serial", { + skip_on_cran() + # Reset RNG before each call so the same subsample is drawn. + set.seed(4242) + ss_serial <- suppressWarnings(loo_subsample( + llfun_test, data = data_fn, draws = draws_fn, + observations = 20, loo_approximation = "plpd", cores = 1 + )) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + set.seed(4242) + ss_parallel <- suppressWarnings(loo_subsample( + llfun_test, data = data_fn, draws = draws_fn, + observations = 20, loo_approximation = "plpd", cores = 2 + )) + + expect_equal(ss_serial$estimates, ss_parallel$estimates) + expect_equal(ss_serial$pointwise, ss_parallel$pointwise) +}) + + +# End-to-end: loo_model_weights (single pool across K models) ---------------- + +test_that("loo_model_weights() parallel equals serial", { + skip_on_cran() + set.seed(11) + ll_list <- list( + matrix(rnorm(200 * 25), nrow = 200), + matrix(rnorm(200 * 25), nrow = 200), + matrix(rnorm(200 * 25), nrow = 200) + ) + wts_serial <- suppressWarnings( + loo_model_weights(ll_list, method = "stacking", cores = 1) + ) + + mirai::daemons(2) + on.exit(mirai::daemons(0), add = TRUE) + wts_parallel <- suppressWarnings( + loo_model_weights(ll_list, method = "stacking", cores = 2) + ) + + expect_equal(as.numeric(wts_serial), as.numeric(wts_parallel), + tolerance = 1e-6) +}) + +# Final safety net in case any test above exited early with a live pool. +mirai::daemons(0) From d7ef7bc32236c7b9b4f9da5c0f4d46f99f89a919 Mon Sep 17 00:00:00 2001 From: Florence Bockting Date: Tue, 30 Jun 2026 16:50:20 +0300 Subject: [PATCH 3/6] update: docs, benchmark, vignette --- .Rbuildignore | 1 + .gitignore | 3 +- DESCRIPTION | 1 + R/parallel.R | 124 +++++++++++++++- _pkgdown.yml | 1 + benchmark/README.md | 136 ++++++++++++++++++ benchmark/bench-comparison.md | 52 +++++++ benchmark/benchmark-parallel.R | 123 ++++++++++++++++ benchmark/compare.R | 218 ++++++++++++++++++++++++++++ benchmark/peak-mem-run.R | 47 ++++++ benchmark/peak-mem.sh | 52 +++++++ man-roxygen/cores.R | 18 +++ man/ap_psis.Rd | 20 +++ man/importance_sampling.Rd | 19 +++ man/loo.Rd | 20 +++ man/loo_approximate_posterior.Rd | 20 +++ man/loo_model_weights.Rd | 20 +++ man/loo_moment_match.Rd | 20 +++ man/loo_moment_match_split.Rd | 20 +++ man/loo_subsample.Rd | 20 +++ man/parallel_psis_list.Rd | 20 +++ man/psis.Rd | 22 ++- man/psis_approximate_posterior.Rd | 20 +++ man/sis.Rd | 22 ++- man/tis.Rd | 22 ++- man/update.psis_loo_ss.Rd | 20 +++ tests/testthat/test_parallel.R | 113 +++++++++++++++ vignettes/loo2-parallel.Rmd | 230 ++++++++++++++++++++++++++++++ 28 files changed, 1397 insertions(+), 7 deletions(-) create mode 100644 benchmark/README.md create mode 100644 benchmark/bench-comparison.md create mode 100644 benchmark/benchmark-parallel.R create mode 100644 benchmark/compare.R create mode 100644 benchmark/peak-mem-run.R create mode 100755 benchmark/peak-mem.sh create mode 100644 vignettes/loo2-parallel.Rmd diff --git a/.Rbuildignore b/.Rbuildignore index 63463781..f23c768d 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -1,3 +1,4 @@ +^benchmark$ ^CRAN-RELEASE$ ^.*\.Rproj$ ^\.Rproj\.user$ diff --git a/.gitignore b/.gitignore index f070106a..55337eba 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,5 @@ release-prep.R agent/* data/* -scratch-files/* \ No newline at end of file +scratch-files/* +notes/* \ No newline at end of file diff --git a/DESCRIPTION b/DESCRIPTION index f673389a..0b33f259 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -64,3 +64,4 @@ LazyData: TRUE Roxygen: list(markdown = TRUE) SystemRequirements: pandoc (>= 1.12.3), pandoc-citeproc Config/roxygen2/version: 8.0.0 +RoxygenNote: 7.3.3 diff --git a/R/parallel.R b/R/parallel.R index 4c8fbe3c..e7ff0f0b 100644 --- a/R/parallel.R +++ b/R/parallel.R @@ -1,3 +1,104 @@ +#' Package-internal state for the parallel backend +#' +#' @noRd +#' @keywords internal +#' @description +#' Holds small bits of session-scoped state used by the parallel helpers: +#' +#' * `cleanup_registered`: guards `loo_register_daemon_cleanup()` so the +#' session-exit finalizer is only registered once. +#' * `warned_bad_daemons`: guards the malformed-config warning in +#' `loo_persist_config()` so it is only emitted once per session. +#' +#' It also serves as the object the daemon-cleanup finalizer is attached to. +.loo_internal <- new.env(parent = emptyenv()) + +#' Resolve the persistent local daemon pool size from user configuration +#' +#' @noRd +#' @keywords internal +#' @description +#' Reads the opt-in "persistent local pool" size from, in precedence order: +#' +#' 1. the R option `loo.daemons`, +#' 2. the environment variable `LOO_DAEMONS`, +#' 3. otherwise the feature is off. +#' +#' This knob enables a local [mirai::daemons()] pool that is created lazily on +#' first parallel use and kept warm for the rest of the session (see +#' `with_loo_daemons()`), which avoids paying pool spawn/teardown overhead on +#' every top-level `loo()`/`psis()` call (useful for simulations, benchmarks +#' and batch/HPC scripts). +#' +#' @return A single integer `>= 2` giving the persistent pool size, or +#' `NA_integer_` when the feature is off (unset, `0`/`1`, or a non-integer +#' value). Genuinely malformed (non-coercible) values warn once per session +#' and then disable the feature. +loo_persist_config <- function() { + raw <- getOption("loo.daemons") + if (is.null(raw)) { + raw <- Sys.getenv("LOO_DAEMONS", unset = NA_character_) + } + if (length(raw) != 1L) { + return(NA_integer_) + } + if (is.na(raw) || (is.character(raw) && !nzchar(trimws(raw)))) { + # Unset / empty -> feature off. + return(NA_integer_) + } + n <- suppressWarnings(as.numeric(raw)) + if (is.na(n) || !is.finite(n)) { + # Non-numeric garbage -> off, but tell the user once that it was ignored. + loo_warn_bad_daemons(raw) + return(NA_integer_) + } + if (n < 2 || n != trunc(n)) { + # 0/1 (serial) or a non-integer value -> feature off, silently. + return(NA_integer_) + } + as.integer(n) +} + +#' Warn (once per session) about a malformed persistent-pool configuration +#' +#' @noRd +#' @keywords internal +loo_warn_bad_daemons <- function(value) { + if (isTRUE(.loo_internal$warned_bad_daemons)) { + return(invisible(NULL)) + } + .loo_internal$warned_bad_daemons <- TRUE + warning( + "Ignoring invalid persistent-pool size ", encodeString(value, quote = "'"), + " from 'loo.daemons'/'LOO_DAEMONS'; expected a single integer >= 2.", + call. = FALSE + ) + invisible(NULL) +} + +#' Register a one-time session-exit cleanup for the persistent daemon pool +#' +#' @noRd +#' @keywords internal +#' @description +#' Attaches a finalizer (only once per session) that resets any local daemon +#' pool with `mirai::daemons(0)` when the R session exits. mirai already +#' terminates local daemons when the host session ends; this is a +#' belt-and-suspenders guard so the lazily created persistent pool never leaves +#' orphan processes behind in batch/HPC scripts. +loo_register_daemon_cleanup <- function() { + if (isTRUE(.loo_internal$cleanup_registered)) { + return(invisible(NULL)) + } + .loo_internal$cleanup_registered <- TRUE + reg.finalizer( + .loo_internal, + function(e) try(mirai::daemons(0), silent = TRUE), + onexit = TRUE + ) + invisible(NULL) +} + #' Evaluate parallel work with an appropriate mirai daemon pool #' #' @noRd @@ -10,7 +111,14 @@ #' * `cores <= 1`: runs `code` serially without touching daemons. #' * A daemon pool is already configured (e.g. the user called #' [mirai::daemons()] themselves, possibly with remote/HPC daemons): `code` -#' runs on the existing pool, which is left untouched. +#' runs on the existing pool, which is left untouched. This always takes +#' precedence over the options below. +#' * Otherwise, if the user opted in to a persistent session pool via the +#' `loo.daemons` option or `LOO_DAEMONS` environment variable (see +#' `loo_persist_config()`): a local pool of that size is created lazily on +#' this first parallel call and left warm for the rest of the session, with a +#' session-exit finalizer registered for cleanup. Subsequent calls reuse it +#' via the existing-pool branch above. #' * Otherwise: a pool of `cores` local daemons is created for the duration of #' `code` and automatically reset afterwards (via the scoped #' `with(mirai::daemons(), ...)` method), so no daemon processes are left @@ -22,14 +130,24 @@ #' existing pool, it is safe to nest: an inner call made while an outer call #' already established a pool simply reuses it instead of creating another. #' -#' @param cores Integer number of cores requested by the user. +#' @param cores Integer number of cores requested by the user. Acts as the +#' per-call "enable parallel" switch; the persistent pool size, when enabled, +#' comes from `loo_persist_config()` rather than from `cores`. #' @param code Expression to evaluate. Lazily evaluated in the calling #' environment, after any daemon pool has been set up. #' @return The value of `code`. with_loo_daemons <- function(cores, code) { if (cores <= 1 || loo_has_pool()) { # Serial work, or reuse the daemon pool the user (or an outer loo call) - # already configured. + # already configured. This always wins over the persistent-pool option. + return(code) + } + persist <- loo_persist_config() + if (!is.na(persist)) { + # Opt-in persistent pool: create once, leave warm for the session, and + # register a finalizer to tidy up at session exit. No per-call teardown. + mirai::daemons(persist) + loo_register_daemon_cleanup() return(code) } # No pool configured: create one scoped to this computation and reset it on diff --git a/_pkgdown.yml b/_pkgdown.yml index 0a216a02..4878e8c2 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -52,6 +52,7 @@ articles: - loo2-non-factorized - loo2-lfo - loo2-large-data + - loo2-parallel - loo2-moment-matching - loo2-mixis - title: Frequently asked questions diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..40174ff0 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,136 @@ +# loo parallel benchmarks + +These scripts measure the performance of loo's parallel code paths and compare +two installed versions of the package side by side: + +- **`baseline`** — a pre-`mirai` version (the old `mclapply`/`parLapply` + backend), e.g. the released version from CRAN. +- **`new`** — the current working tree (the `mirai` + `mori` backend, including + the persistent session pool controlled by `options(loo.daemons = k)` / + `LOO_DAEMONS`). + +The same user-facing calls (`psis()`, `loo()`) are timed for every version; only +the internal parallel backend differs. For the `new` version we additionally +time a **persist** mode that opts in to the persistent session pool, so we can +separate per-call daemon spawn/teardown overhead from the steady-state cost. + +## Files + +| File | Purpose | +|---|---| +| `benchmark-parallel.R` | Times `psis()`/`loo()` across cores for one installed version; writes `/tmp/bench-