Skip to content

Don't library loo in tests #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions tests/testthat/test_0_helpers.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
library(loo)

LLarr <- example_loglik_array()
LLmat <- example_loglik_matrix()

Expand All @@ -24,18 +22,26 @@ test_that("reshaping functions result in correct dimensions", {
})

test_that("reshaping functions throw correct errors", {
expect_error(llmatrix_to_array(LLmat, chain_id = rep(1:2, times = c(400, 600))),
regexp = "Not all chains have same number of iterations",
fixed = TRUE)
expect_error(llmatrix_to_array(LLmat, chain_id = rep(1:2, each = 400)),
regexp = "Number of rows in matrix not equal to length(chain_id)",
fixed = TRUE)
expect_error(llmatrix_to_array(LLmat, chain_id = rep(2:3, each = 500)),
regexp = "max(chain_id) not equal to the number of chains",
fixed = TRUE)
expect_error(llmatrix_to_array(LLmat, chain_id = rnorm(1000)),
regexp = "all(chain_id == as.integer(chain_id)) is not TRUE",
fixed = TRUE)
expect_error(
llmatrix_to_array(LLmat, chain_id = rep(1:2, times = c(400, 600))),
regexp = "Not all chains have same number of iterations",
fixed = TRUE
)
expect_error(
llmatrix_to_array(LLmat, chain_id = rep(1:2, each = 400)),
regexp = "Number of rows in matrix not equal to length(chain_id)",
fixed = TRUE
)
expect_error(
llmatrix_to_array(LLmat, chain_id = rep(2:3, each = 500)),
regexp = "max(chain_id) not equal to the number of chains",
fixed = TRUE
)
expect_error(
llmatrix_to_array(LLmat, chain_id = rnorm(1000)),
regexp = "all(chain_id == as.integer(chain_id)) is not TRUE",
fixed = TRUE
)
})

test_that("colLogMeanExps(x) = log(colMeans(exp(x))) ", {
Expand All @@ -54,9 +60,14 @@ test_that("validating log-lik objects and functions works", {
})

test_that("nlist works", {
a <- 1; b <- 2; c <- 3;
a <- 1
b <- 2
c <- 3
nlist_val <- list(nlist(a, b, c), nlist(a, b, c = "tornado"))
nlist_ans <- list(list(a = 1, b = 2, c = 3), list(a = 1, b = 2, c = "tornado"))
nlist_ans <- list(
list(a = 1, b = 2, c = 3),
list(a = 1, b = 2, c = "tornado")
)
expect_equal(nlist_val, nlist_ans)
expect_equal(nlist(a = 1, b = 2, c = 3), list(a = 1, b = 2, c = 3))
})
Expand All @@ -69,6 +80,5 @@ test_that("loo_cores works", {

options(loo.cores = 2)
expect_warning(expect_equal(loo_cores(10), 2), "deprecated")
options(loo.cores=NULL)
options(loo.cores = NULL)
})

66 changes: 51 additions & 15 deletions tests/testthat/test_E_loo.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
library(loo)

LLarr <- example_loglik_array()
LLmat <- example_loglik_matrix()
LLvec <- LLmat[, 1]
Expand All @@ -17,15 +15,54 @@ log_rats <- -LLmat
E_test_mean <- E_loo(x, psis_mat, type = "mean", log_ratios = log_rats)
E_test_var <- E_loo(x, psis_mat, type = "var", log_ratios = log_rats)
E_test_sd <- E_loo(x, psis_mat, type = "sd", log_ratios = log_rats)
E_test_quant <- E_loo(x, psis_mat, type = "quantile", probs = 0.5, log_ratios = log_rats)
E_test_quant2 <- E_loo(x, psis_mat, type = "quantile", probs = c(0.1, 0.9), log_ratios = log_rats)
E_test_quant <- E_loo(
x,
psis_mat,
type = "quantile",
probs = 0.5,
log_ratios = log_rats
)
E_test_quant2 <- E_loo(
x,
psis_mat,
type = "quantile",
probs = c(0.1, 0.9),
log_ratios = log_rats
)

# vector method
E_test_mean_vec <- E_loo(x[, 1], psis_vec, type = "mean", log_ratios = log_rats[,1])
E_test_var_vec <- E_loo(x[, 1], psis_vec, type = "var", log_ratios = log_rats[,1])
E_test_sd_vec <- E_loo(x[, 1], psis_vec, type = "sd", log_ratios = log_rats[,1])
E_test_quant_vec <- E_loo(x[, 1], psis_vec, type = "quant", probs = 0.5, log_ratios = log_rats[,1])
E_test_quant_vec2 <- E_loo(x[, 1], psis_vec, type = "quant", probs = c(0.1, 0.5, 0.9), log_ratios = log_rats[,1])
E_test_mean_vec <- E_loo(
x[, 1],
psis_vec,
type = "mean",
log_ratios = log_rats[, 1]
)
E_test_var_vec <- E_loo(
x[, 1],
psis_vec,
type = "var",
log_ratios = log_rats[, 1]
)
E_test_sd_vec <- E_loo(
x[, 1],
psis_vec,
type = "sd",
log_ratios = log_rats[, 1]
)
E_test_quant_vec <- E_loo(
x[, 1],
psis_vec,
type = "quant",
probs = 0.5,
log_ratios = log_rats[, 1]
)
E_test_quant_vec2 <- E_loo(
x[, 1],
psis_vec,
type = "quant",
probs = c(0.1, 0.5, 0.9),
log_ratios = log_rats[, 1]
)

# E_loo_khat
khat <- loo:::E_loo_khat.matrix(x, psis_mat, log_rats)
Expand Down Expand Up @@ -114,11 +151,11 @@ test_that("E_loo throws correct errors and warnings", {
# warnings
expect_no_warning(E_loo.matrix(x, psis_mat))
# no warnings if x is constant, binary, NA, NaN, Inf
expect_no_warning(E_loo.matrix(x*0, psis_mat))
expect_no_warning(E_loo.matrix(0+(x>0), psis_mat))
expect_no_warning(E_loo.matrix(x+NA, psis_mat))
expect_no_warning(E_loo.matrix(x*NaN, psis_mat))
expect_no_warning(E_loo.matrix(x*Inf, psis_mat))
expect_no_warning(E_loo.matrix(x * 0, psis_mat))
expect_no_warning(E_loo.matrix(0 + (x > 0), psis_mat))
expect_no_warning(E_loo.matrix(x + NA, psis_mat))
expect_no_warning(E_loo.matrix(x * NaN, psis_mat))
expect_no_warning(E_loo.matrix(x * Inf, psis_mat))
expect_no_warning(E_test <- E_loo.default(x[, 1], psis_vec))
expect_length(E_test$pareto_k, 1)

Expand Down Expand Up @@ -161,7 +198,6 @@ test_that("weighted quantiles work", {
quantile(xx, probs, names = FALSE)
}


set.seed(123)
pr <- seq(0.025, 0.975, 0.025)

Expand Down
114 changes: 75 additions & 39 deletions tests/testthat/test_compare.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
library(loo)
set.seed(123)

LLarr <- example_loglik_array()
Expand All @@ -12,45 +11,63 @@ test_that("loo_compare throws appropriate errors", {
w4 <- suppressWarnings(waic(LLarr[,, -(1:2)]))

expect_error(loo_compare(2, 3), "must be a list if not a 'loo' object")
expect_error(loo_compare(w1, w2, x = list(w1, w2)),
"If 'x' is a list then '...' should not be specified")
expect_error(loo_compare(w1, list(1,2,3)), "class 'loo'")
expect_error(
loo_compare(w1, w2, x = list(w1, w2)),
"If 'x' is a list then '...' should not be specified"
)
expect_error(loo_compare(w1, list(1, 2, 3)), "class 'loo'")
expect_error(loo_compare(w1), "requires at least two models")
expect_error(loo_compare(x = list(w1)), "requires at least two models")
expect_error(loo_compare(w1, w3), "same number of data points")
expect_error(loo_compare(w1, w2, w3), "same number of data points")
})

test_that("loo_compare throws appropriate warnings", {
w3 <- w1; w4 <- w2
w3 <- w1
w4 <- w2
class(w3) <- class(w4) <- c("kfold", "loo")
attr(w3, "K") <- 2
attr(w4, "K") <- 3
expect_warning(loo_compare(w3, w4), "Not all kfold objects have the same K value")
expect_warning(
loo_compare(w3, w4),
"Not all kfold objects have the same K value"
)

class(w4) <- c("psis_loo", "loo")
attr(w4, "K") <- NULL
expect_warning(loo_compare(w3, w4), "Comparing LOO-CV to K-fold-CV")

w3 <- w1; w4 <- w2
w3 <- w1
w4 <- w2
attr(w3, "yhash") <- "a"
attr(w4, "yhash") <- "b"
expect_warning(loo_compare(w3, w4), "Not all models have the same y variable")

set.seed(123)
w_list <- lapply(1:25, function(x) suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1))))
expect_warning(loo_compare(w_list),
"Difference in performance potentially due to chance")

w_list_short <- lapply(1:4, function(x) suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1))))
w_list <- lapply(1:25, function(x) {
suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1)))
})
expect_warning(
loo_compare(w_list),
"Difference in performance potentially due to chance"
)

w_list_short <- lapply(1:4, function(x) {
suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1)))
})
expect_no_warning(loo_compare(w_list_short))
})



comp_colnames <- c(
"elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic",
"p_waic", "se_p_waic", "waic", "se_waic"
"elpd_diff",
"se_diff",
"elpd_waic",
"se_elpd_waic",
"p_waic",
"se_p_waic",
"waic",
"se_waic"
)

test_that("loo_compare returns expected results (2 models)", {
Expand All @@ -59,15 +76,15 @@ test_that("loo_compare returns expected results (2 models)", {
expect_equal(colnames(comp1), comp_colnames)
expect_equal(rownames(comp1), c("model1", "model2"))
expect_output(print(comp1), "elpd_diff")
expect_equal(comp1[1:2,1], c(0, 0), ignore_attr = TRUE)
expect_equal(comp1[1:2,2], c(0, 0), ignore_attr = TRUE)
expect_equal(comp1[1:2, 1], c(0, 0), ignore_attr = TRUE)
expect_equal(comp1[1:2, 2], c(0, 0), ignore_attr = TRUE)

comp2 <- loo_compare(w1, w2)
expect_s3_class(comp2, "compare.loo")
expect_equal(colnames(comp2), comp_colnames)

expect_snapshot_value(comp2, style = "serialize")

# specifying objects via ... and via arg x gives equal results
expect_equal(comp2, loo_compare(x = list(w1, w2)))
})
Expand All @@ -79,7 +96,7 @@ test_that("loo_compare returns expected result (3 models)", {

expect_equal(colnames(comp1), comp_colnames)
expect_equal(rownames(comp1), c("model1", "model2", "model3"))
expect_equal(comp1[1,1], 0)
expect_equal(comp1[1, 1], 0)
expect_s3_class(comp1, "compare.loo")
expect_s3_class(comp1, "matrix")

Expand Down Expand Up @@ -119,34 +136,53 @@ test_that("compare returns expected result (3 models)", {
expect_equal(
colnames(comp1),
c(
"elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic",
"p_waic", "se_p_waic", "waic", "se_waic"
))
"elpd_diff",
"se_diff",
"elpd_waic",
"se_elpd_waic",
"p_waic",
"se_p_waic",
"waic",
"se_waic"
)
)
expect_equal(rownames(comp1), c("w1", "w2", "w3"))
expect_equal(comp1[1,1], 0)
expect_equal(comp1[1, 1], 0)
expect_s3_class(comp1, "compare.loo")
expect_s3_class(comp1, "matrix")
expect_snapshot_value(comp1, style = "serialize")

# specifying objects via '...' gives equivalent results (equal
# except rownames) to using 'x' argument
expect_warning(comp_via_list <- loo::compare(x = list(w1, w2, w3)), "Deprecated")
expect_warning(
comp_via_list <- loo::compare(x = list(w1, w2, w3)),
"Deprecated"
)
expect_equal(comp1, comp_via_list, ignore_attr = TRUE)
})

test_that("compare throws appropriate errors", {
expect_error(suppressWarnings(loo::compare(w1, w2, x = list(w1, w2))),
"should not be specified")
expect_error(suppressWarnings(loo::compare(x = 2)),
"must be a list")
expect_error(suppressWarnings(loo::compare(x = list(2))),
"should have class 'loo'")
expect_error(suppressWarnings(loo::compare(x = list(w1))),
"requires at least two models")

w3 <- suppressWarnings(waic(LLarr2[,,-1]))
expect_error(suppressWarnings(loo::compare(x = list(w1, w3))),
"same number of data points")
expect_error(suppressWarnings(loo::compare(x = list(w1, w2, w3))),
"same number of data points")
expect_error(
suppressWarnings(loo::compare(w1, w2, x = list(w1, w2))),
"should not be specified"
)
expect_error(suppressWarnings(loo::compare(x = 2)), "must be a list")
expect_error(
suppressWarnings(loo::compare(x = list(2))),
"should have class 'loo'"
)
expect_error(
suppressWarnings(loo::compare(x = list(w1))),
"requires at least two models"
)

w3 <- suppressWarnings(waic(LLarr2[,, -1]))
expect_error(
suppressWarnings(loo::compare(x = list(w1, w3))),
"same number of data points"
)
expect_error(
suppressWarnings(loo::compare(x = list(w1, w2, w3))),
"same number of data points"
)
})
1 change: 0 additions & 1 deletion tests/testthat/test_deprecated_extractors.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
library(loo)
options(mc.cores = 1)
set.seed(123)

Expand Down
2 changes: 0 additions & 2 deletions tests/testthat/test_extract_log_lik.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
library(loo)

test_that("extract_log_lik throws appropriate errors", {
x1 <- rnorm(100)
expect_error(extract_log_lik(x1), regexp = "Not a stanfit object")
Expand Down
10 changes: 4 additions & 6 deletions tests/testthat/test_gpdfit.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
library(loo)

test_that("gpdfit returns correct result", {
set.seed(123)
x <- rexp(100)
gpdfit_val_old <- unlist(gpdfit(x, wip=FALSE, min_grid_pts = 80))
gpdfit_val_old <- unlist(gpdfit(x, wip = FALSE, min_grid_pts = 80))
expect_snapshot_value(gpdfit_val_old, style = "serialize")

gpdfit_val_wip <- unlist(gpdfit(x, wip=TRUE, min_grid_pts = 80))
gpdfit_val_wip <- unlist(gpdfit(x, wip = TRUE, min_grid_pts = 80))
expect_snapshot_value(gpdfit_val_wip, style = "serialize")

gpdfit_val_wip_default_grid <- unlist(gpdfit(x, wip=TRUE))
gpdfit_val_wip_default_grid <- unlist(gpdfit(x, wip = TRUE))
expect_snapshot_value(gpdfit_val_wip_default_grid, style = "serialize")
})

test_that("qgpd returns the correct result ", {
probs <- seq(from = 0, to = 1, by = 0.25)
q1 <- qgpd(probs, k = 1, sigma = 1)
expect_equal(q1, c(0, 1/3, 1, 3, Inf))
expect_equal(q1, c(0, 1 / 3, 1, 3, Inf))

q2 <- qgpd(probs, k = 1, sigma = 0)
expect_true(all(is.nan(q2)))
Expand Down
Loading