diff --git a/R/predict_class.R b/R/predict_class.R index e96d460b1..3c8fe69a6 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -9,21 +9,23 @@ #' @export predict_class.model_fit #' @export predict_class.model_fit <- function(object, new_data, ...) { - if (object$spec$mode != "classification") - rlang::abort("`predict.model_fit()` is for predicting factor outcomes.") + if (object$spec$mode != "classification") { + cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.") + } check_spec_pred_type(object, "class") if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$class$pre)) + if (!is.null(object$spec$method$pred$class$pre)) { new_data <- object$spec$method$pred$class$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$class) @@ -56,6 +58,6 @@ predict_class.model_fit <- function(object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_class <- function(object, ...) +predict_class <- function(object, ...) { UseMethod("predict_class") - +} diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 19190f8a7..4fb5fb957 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -5,22 +5,24 @@ #' @export predict_classprob.model_fit #' @export predict_classprob.model_fit <- function(object, new_data, ...) { - if (object$spec$mode != "classification") - rlang::abort("`predict.model_fit()` is for predicting factor outcomes.") + if (object$spec$mode != "classification") { + cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.") + } check_spec_pred_type(object, "prob") check_spec_levels(object) if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$prob$pre)) + if (!is.null(object$spec$method$pred$prob$pre)) { new_data <- object$spec$method$pred$prob$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$prob) @@ -33,11 +35,13 @@ predict_classprob.model_fit <- function(object, new_data, ...) { } # check and sort names - if (!is.data.frame(res) & !inherits(res, "tbl_spark")) - rlang::abort("The was a problem with the probability predictions.") + if (!is.data.frame(res) & !inherits(res, "tbl_spark")) { + cli::cli_abort("The was a problem with the probability predictions.") + } - if (!is_tibble(res) & !inherits(res, "tbl_spark")) + if (!is_tibble(res) & !inherits(res, "tbl_spark")) { res <- as_tibble(res) + } res } @@ -46,18 +50,19 @@ predict_classprob.model_fit <- function(object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_classprob <- function(object, ...) +predict_classprob <- function(object, ...) { UseMethod("predict_classprob") +} check_spec_levels <- function(spec) { if ("class" %in% spec$lvl) { - rlang::abort( - glue::glue( - "The outcome variable `{spec$preproc$y_var}` has a level called 'class'. ", - "This value is reserved for parsnip's classification internals; please ", - "change the levels, perhaps with `forcats::fct_relevel()`." - ), - call = NULL + cli::cli_abort( + c( + "The outcome variable {.var {spec$preproc$y_var}} has a level called {.val class}.", + "i" = "This value is reserved for parsnip's classification internals; please + change the levels, perhaps with {.fn forcats::fct_relevel}.", + call = NULL + ) ) } } diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 30cdc6a83..6f8aed916 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -5,29 +5,35 @@ #' @export predict_numeric.model_fit #' @export predict_numeric.model_fit <- function(object, new_data, ...) { - if (object$spec$mode != "regression") - rlang::abort(glue::glue("`predict_numeric()` is for predicting numeric outcomes. ", - "Use `predict_class()` or `predict_classprob()` for ", - "classification models.")) + if (object$spec$mode != "regression") { + cli::cli_abort( + c( + "{.fun predict_numeric} is for predicting numeric outcomes.", + "i" = "Use {.fun predict_class} or {.fun predict_classprob} for + classification models." + ) + ) + } check_spec_pred_type(object, "numeric") if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$numeric$pre)) + if (!is.null(object$spec$method$pred$numeric$pre)) { new_data <- object$spec$method$pred$numeric$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$numeric) res <- eval_tidy(pred_call) - + # post-process the predictions if (!is.null(object$spec$method$pred$numeric$post)) { res <- object$spec$method$pred$numeric$post(res, object) @@ -36,8 +42,9 @@ predict_numeric.model_fit <- function(object, new_data, ...) { if (is.vector(res)) { res <- unname(res) } else { - if (!inherits(res, "tbl_spark")) + if (!inherits(res, "tbl_spark")) { res <- as.data.frame(res) + } } res } @@ -47,5 +54,6 @@ predict_numeric.model_fit <- function(object, new_data, ...) { #' @keywords internal #' @rdname other_predict #' @inheritParams predict_numeric.model_fit -predict_numeric <- function(object, ...) +predict_numeric <- function(object, ...) { UseMethod("predict_numeric") +} diff --git a/R/predict_time.R b/R/predict_time.R index 2c512c103..769b7a578 100644 --- a/R/predict_time.R +++ b/R/predict_time.R @@ -5,29 +5,35 @@ #' @export predict_time.model_fit #' @export predict_time.model_fit <- function(object, new_data, ...) { - if (object$spec$mode != "censored regression") - rlang::abort(glue::glue("`predict_time()` is for predicting time outcomes. ", - "Use `predict_class()` or `predict_classprob()` for ", - "classification models.")) + if (object$spec$mode != "censored regression") { + cli::cli_abort( + c( + "{.fun predict_time} is for predicting time outcomes.", + "i" = "Use {.fun predict_class} or {.fun predict_classprob} for + classification models." + ) + ) + } check_spec_pred_type(object, "time") if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$time$pre)) + if (!is.null(object$spec$method$pred$time$pre)) { new_data <- object$spec$method$pred$time$pre(new_data, object) + } # create prediction call pred_call <- make_pred_call(object$spec$method$pred$time) res <- eval_tidy(pred_call) - + # post-process the predictions if (!is.null(object$spec$method$pred$time$post)) { res <- object$spec$method$pred$time$post(res, object) @@ -45,5 +51,6 @@ predict_time.model_fit <- function(object, new_data, ...) { #' @keywords internal #' @rdname other_predict #' @inheritParams predict_time.model_fit -predict_time <- function(object, ...) +predict_time <- function(object, ...) { UseMethod("predict_time") +} diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index ff6fa4475..2dbf38b44 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -76,7 +76,7 @@ test_that('predict(type = "prob") with level "class" (see #720)', { ) expect_error( - regexp = "variable `boop` has a level called 'class'", + regexp = 'variable `boop` has a level called "class"', predict(mod, type = "prob", new_data = x) ) })