diff --git a/NEWS.md b/NEWS.md index dda1466a56..96a5333c06 100644 --- a/NEWS.md +++ b/NEWS.md @@ -118,6 +118,10 @@ the limits of the scale and ignore the order of any `breaks` provided. Note that this may change the appearance of plots that previously relied on the unordered behaviour (#2429, @idno0001). + +* `stat_summary()` and related functions now support rlang-style lambda functions + (#3568, @dkahle). + # ggplot2 3.2.1 diff --git a/R/stat-summary-2d.r b/R/stat-summary-2d.r index 2fafaac418..3808b1d89f 100644 --- a/R/stat-summary-2d.r +++ b/R/stat-summary-2d.r @@ -30,11 +30,13 @@ #' #' # Specifying function #' d + stat_summary_2d(fun = function(x) sum(x^2)) +#' d + stat_summary_2d(fun = ~ sum(.x^2)) #' d + stat_summary_2d(fun = var) #' d + stat_summary_2d(fun = "quantile", fun.args = list(probs = 0.1)) #' #' if (requireNamespace("hexbin")) { #' d + stat_summary_hex() +#' d + stat_summary_hex(fun = ~ sum(.x^2)) #' } stat_summary_2d <- function(mapping = NULL, data = NULL, geom = "tile", position = "identity", @@ -98,6 +100,7 @@ StatSummary2d <- ggproto("StatSummary2d", Stat, xbin <- cut(data$x, xbreaks, include.lowest = TRUE, labels = FALSE) ybin <- cut(data$y, ybreaks, include.lowest = TRUE, labels = FALSE) + fun <- as_function(fun) f <- function(x) { do.call(fun, c(list(quote(x)), fun.args)) } diff --git a/R/stat-summary-bin.R b/R/stat-summary-bin.R index 811f598faa..98a13a3346 100644 --- a/R/stat-summary-bin.R +++ b/R/stat-summary-bin.R @@ -96,7 +96,7 @@ make_summary_fun <- function(fun.data, fun, fun.max, fun.min, fun.args) { if (!is.null(fun.data)) { # Function that takes complete data frame as input - fun.data <- match.fun(fun.data) + fun.data <- as_function(fun.data) function(df) { do.call(fun.data, c(list(quote(df$y)), fun.args)) } @@ -105,6 +105,7 @@ make_summary_fun <- function(fun.data, fun, fun.max, fun.min, fun.args) { call_f <- function(fun, x) { if (is.null(fun)) return(NA_real_) + fun <- as_function(fun) do.call(fun, c(list(quote(x)), fun.args)) } @@ -116,7 +117,7 @@ make_summary_fun <- function(fun.data, fun, fun.max, fun.min, fun.args) { )) } } else { - message("No summary function supplied, defaulting to `mean_se()") + message("No summary function supplied, defaulting to `mean_se()`") function(df) { mean_se(df$y) } diff --git a/R/stat-summary-hex.r b/R/stat-summary-hex.r index fe29b7b3be..6ba1d6822b 100644 --- a/R/stat-summary-hex.r +++ b/R/stat-summary-hex.r @@ -46,6 +46,7 @@ StatSummaryHex <- ggproto("StatSummaryHex", Stat, try_require("hexbin", "stat_summary_hex") binwidth <- binwidth %||% hex_binwidth(bins, scales) + fun <- as_function(fun) hexBinSummarise(data$x, data$y, data$z, binwidth, fun = fun, fun.args = fun.args, drop = drop) } diff --git a/man/stat_summary_2d.Rd b/man/stat_summary_2d.Rd index f51109717c..e03960c16d 100644 --- a/man/stat_summary_2d.Rd +++ b/man/stat_summary_2d.Rd @@ -125,11 +125,13 @@ d + stat_summary_2d() # Specifying function d + stat_summary_2d(fun = function(x) sum(x^2)) +d + stat_summary_2d(fun = ~ sum(.x^2)) d + stat_summary_2d(fun = var) d + stat_summary_2d(fun = "quantile", fun.args = list(probs = 0.1)) if (requireNamespace("hexbin")) { d + stat_summary_hex() +d + stat_summary_hex(fun = ~ sum(.x^2)) } } \seealso{ diff --git a/tests/testthat/test-stat-summary.R b/tests/testthat/test-stat-summary.R new file mode 100644 index 0000000000..1c4b29a166 --- /dev/null +++ b/tests/testthat/test-stat-summary.R @@ -0,0 +1,84 @@ +context("stat_summary") + +test_that("stat_summary(_bin) work with lambda expressions", { + # note: stat_summary and stat_summary_bin both use + # make_summary_fun, so this tests both + + dat <- data_frame( + x = c(1, 1, 2, 2, 3, 3), + y = c(0, 2, 1, 3, 2, 4) + ) + + p1 <- ggplot(dat, aes(x, y)) + + stat_summary(fun.data = mean_se) + + + # test fun.data + p2 <- ggplot(dat, aes(x, y)) + + stat_summary(fun.data = ~ { + mean <- mean(.x) + se <- sqrt(stats::var(.x) / length(.x)) + data_frame(y = mean, ymin = mean - se, ymax = mean + se) + }) + + expect_equal( + layer_data(p1), + layer_data(p2) + ) + + + # fun, fun.min, fun.max + p3 <- ggplot(dat, aes(x, y)) + + stat_summary( + fun = ~ mean(.x), + fun.min = ~ mean(.x) - sqrt(stats::var(.x) / length(.x)), + fun.max = ~ mean(.x) + sqrt(stats::var(.x) / length(.x)) + ) + + expect_equal( + layer_data(p1), + layer_data(p3) + ) + +}) + + + + +test_that("stat_summary_(2d|hex) work with lambda expressions", { + + dat <- data_frame( + x = c(0, 0, 0, 0, 1, 1, 1, 1), + y = c(0, 0, 1, 1, 0, 0, 1, 1), + z = c(1, 1, 2, 2, 2, 2, 3, 3) + ) + + + # stat_summary_2d + p1 <- ggplot(dat, aes(x, y, z = z)) + + stat_summary_2d(fun = function(x) mean(x)) + + p2 <- ggplot(dat, aes(x, y, z = z)) + + stat_summary_2d(fun = ~ mean(.x)) + + expect_equal( + layer_data(p1), + layer_data(p2) + ) + + + + # stat_summary_hex + # this plot is a bit funky, but easy to reason through + p1 <- ggplot(dat, aes(x, y, z = z)) + + stat_summary_hex(fun = function(x) mean(x)) + + p2 <- ggplot(dat, aes(x, y, z = z)) + + stat_summary_hex(fun = ~ mean(.x)) + + expect_equal( + layer_data(p1), + layer_data(p2) + ) + +})