Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.12
Version: 0.1.13
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- adjust default quantiles throughout so that they match.
- force `layer_residual_quantiles()` to always include `0.5`.
- Rename `recipes:::check_training_set()` to `recipes:::validate_training_data()`, as it changed in recipes 1.1.0.
- A new column name duplicating an existing column name results in an error instead of a random name.

# epipredict 0.1

Expand Down
1 change: 0 additions & 1 deletion R/layer_cdc_flatline_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ slather.layer_cdc_flatline_quantiles <-
) %>%
select(all_of(c(avail_grps, ".pred_distn_all")))

# res <- check_pname(res, components$predictions, object)
components$predictions <- left_join(
components$predictions,
res,
Expand Down
7 changes: 6 additions & 1 deletion R/layer_point_from_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ slather.layer_point_from_distn <-
components$predictions$.pred <- dstn
} else {
dstn <- tibble(dstn = dstn)
dstn <- check_pname(dstn, components$predictions, object)
dstn <- check_name(
dstn,
components$predictions,
object,
newname = object$name
)
components$predictions <- mutate(components$predictions, !!!dstn)
}
components
Expand Down
7 changes: 6 additions & 1 deletion R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ slather.layer_quantile_distn <-
dstn <- snap(dstn, truncate[1], truncate[2])
}
dstn <- tibble(dstn = dstn)
dstn <- check_pname(dstn, components$predictions, object)
dstn <- check_name(
dstn,
components$predictions,
object,
newname = object$name
)
components$predictions <- mutate(components$predictions, !!!dstn)
components
}
Expand Down
7 changes: 6 additions & 1 deletion R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,12 @@ slather.layer_residual_quantiles <-

estimate <- components$predictions$.pred
res <- tibble(.pred_distn = r$dstn + estimate)
res <- check_pname(res, components$predictions, object)
res <- check_name(
res,
components$predictions,
object,
newname = object$name
)
components$predictions <- mutate(components$predictions, !!!res)
components
}
Expand Down
33 changes: 0 additions & 33 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
@@ -1,36 +1,3 @@
#' Check that newly created variable names don't overlap
#'
#' `check_pname` is to be used in a slather method to ensure that
#' newly created variable names don't overlap with existing names.
#' Throws an warning if check fails, and creates a random string.
#' @param res A data frame or tibble of the newly created variables.
#' @param preds An epi_df or tibble containing predictions.
#' @param object A layer object passed to [slather()].
#' @param newname A string of variable names if the object doesn't contain a
#' $name element
#'
#' @keywords internal
check_pname <- function(res, preds, object, newname = NULL) {
if (is.null(newname)) newname <- object$name
new_preds_names <- colnames(preds)
intersection <- new_preds_names %in% newname
if (any(intersection)) {
newname <- rand_id(newname)
rlang::warn(
paste0(
"Name collision occured in `",
class(object)[1],
"`. The following variable names already exists: ",
paste0(new_preds_names[intersection], collapse = ", "),
". Result instead has randomly generated string `",
newname, "`."
)
)
}
names(res) <- newname
res
}

# Copied from `epiprocess`:

#' "Format" a character vector of column/variable names for cli interpolation
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test-layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ test_that("Returns expected number or rows and columns", {
expect_equal(unique(unnested$.pred_distn_quantile_level), c(.0275, 0.5, .8, .95))
})

tests_that("new name works correctly", {
f <- frosting() %>%
layer_predict() %>%
layer_naomit(.pred) %>%
layer_residual_quantiles(name = "foo")

wf1 <- wf %>% add_frosting(f)
expect_equal(names(forecast(wf1)), c("geo_value", "time_value", ".pred", "foo"))

f <- frosting() %>%
layer_predict() %>%
layer_naomit(.pred) %>%
layer_residual_quantiles(name = "geo_value")

wf1 <- wf %>% add_frosting(f)
expect_error(forecast(wf1))
})

test_that("Errors when used with a classifier", {
tib <- tibble(
Expand Down
Loading