Adjust ahead #296

Merged
merged 92 commits into from
Oct 1, 2024
249b465
fix warnings and empty tests
dsweber2 Mar 16, 2024
c622d7d
first draft of extend_ahead
dsweber2 Mar 18, 2024
e6c19e9
extend_ahead version bump and news
dsweber2 May 3, 2024
2b68062
styler has opinions
dsweber2 Mar 18, 2024
2f1ab39
separate step version
dsweber2 Mar 29, 2024
07a9e28
styler
dsweber2 Mar 29, 2024
d4d617f
tests for utils-latency and accompanying fixes
dsweber2 Apr 1, 2024
80e64b5
adding stringr
dsweber2 Apr 1, 2024
aa87607
old snapshots, select prefers `all_of` for vectors
dsweber2 Apr 1, 2024
375af6d
local renv way out of date
dsweber2 Apr 1, 2024
24eca50
pkgdown needs @keywords internal
dsweber2 Apr 1, 2024
9e9b1b4
passes local tests after updating
dsweber2 Apr 2, 2024
7ae26d2
back to skipping some population_scaling tests
dsweber2 Apr 2, 2024
e346f71
step_adjust_latency works on tests
dsweber2 Apr 25, 2024
05b5cbf
spurious lifecycle addition removed
dsweber2 May 3, 2024
47cb5b7
fixing RMDcheck remote
dsweber2 May 3, 2024
4b0b668
nothing but `rlang::abort` -> `cli::cli_abort`s
dsweber2 May 3, 2024
2731160
smaller suggestions and styling
dsweber2 May 3, 2024
5c1e15e
smaller suggestions: local tests passing again
dsweber2 May 6, 2024
c48e81a
moving shift detection earlier,dropping string*dep
dsweber2 May 8, 2024
8028374
+purrr, styling
dsweber2 May 8, 2024
4a0ed48
glue -> glue::glue
dsweber2 May 8, 2024
909e47c
fix get_latent_column_tibble docs
dsweber2 May 8, 2024
8639ebd
step_adjust_latency arg docs
dsweber2 May 8, 2024
55314a8
rec formatting things, dropping `purrr`
dsweber2 May 13, 2024
ce230ac
glue->paste, dropping zoo
dsweber2 May 14, 2024
c8f6b85
Detecting required/forbidden steps beforehand
dsweber2 May 16, 2024
4927f0e
minor rebase woes
dsweber2 May 17, 2024
7752b17
tests for utils-latency and accompanying fixes
dsweber2 Apr 1, 2024
8f3641b
adding stringr
dsweber2 Apr 1, 2024
ba0c4b8
nothing but `rlang::abort` -> `cli::cli_abort`s
dsweber2 May 3, 2024
27694ef
moving shift detection earlier,dropping string*dep
dsweber2 May 8, 2024
3eab9c2
rec formatting things, dropping `purrr`
dsweber2 May 13, 2024
7aa06e7
initial layer adjustments
dsweber2 May 15, 2024
be3474c
namespace and doc fixes
dsweber2 May 17, 2024
6c158ce
full rebase fixes
dsweber2 May 17, 2024
4f71715
adding latency adjusting to arx_forecaster
dsweber2 May 17, 2024
e102d41
arx_classifier more or less free
dsweber2 May 17, 2024
65535c5
formatting and snapshots
dsweber2 May 17, 2024
5d5cfbb
updated man pages
dsweber2 May 22, 2024
084acb6
group_by options to get the max_time_value
dsweber2 May 24, 2024
d2e2f95
PR review recs
dsweber2 May 29, 2024
99d8099
typo in multiline pipe replacement
dsweber2 May 29, 2024
c4fce2e
happy styler
dsweber2 Jun 3, 2024
f5ae9d1
various requested changes, check passes
dsweber2 Jun 14, 2024
be9607b
style fix
dsweber2 Jun 14, 2024
a86b3c7
inheritParams, correct print, test adjust subset
dsweber2 Jun 14, 2024
5b7eff1
space
dsweber2 Jun 14, 2024
63b02c9
print fix and tests
dsweber2 Jun 22, 2024
fc8b0c0
multi-aheads do work
dsweber2 Jun 24, 2024
a570a0e
arx_fc better fc_date info, docs
dsweber2 Jun 28, 2024
a65cad0
classifier latency ahead adjustment
dsweber2 Jun 28, 2024
e3a368e
refactor step_adjust_ahead to be early step
dsweber2 Jul 3, 2024
09fbfd8
moving locf to step_adjust_ahead instead of get_test_data
dsweber2 Jul 3, 2024
01dc148
hotfix from Dan
dsweber2 Jul 8, 2024
a5a84a7
rebase fixes, error classes, unskip latency tests
dsweber2 Jul 30, 2024
bfde279
rebase fixes round 2
dsweber2 Sep 4, 2024
8119e72
NEWS+Description, partial locf tests, docs
dsweber2 Sep 4, 2024
7816d13
testing the step
dsweber2 Sep 4, 2024
edc2b3f
step locf tests passing, grf pkgdown
dsweber2 Sep 5, 2024
acc5fa0
locf correct on NA at end columns
dsweber2 Sep 6, 2024
a11fa5f
docs along with more extensive tests
dsweber2 Sep 9, 2024
71694d7
non-timezone dependent printing tests
dsweber2 Sep 9, 2024
aa9f1e3
arx_forecaster consistency check and tests
dsweber2 Sep 9, 2024
899ea51
arx_forecaster updates
dsweber2 Sep 10, 2024
af45eaf
arx_classifier addition
dsweber2 Sep 10, 2024
1155d30
formatting
dsweber2 Sep 10, 2024
343551f
various minor fixes caught in pre-review
dsweber2 Sep 11, 2024
60827a3
spurious join_by tests removed
dsweber2 Sep 11, 2024
60b3fad
vignettes: no more fill_locf, some missing data
dsweber2 Sep 11, 2024
28c5863
drop :: for cli, many dplyr commands
dsweber2 Sep 12, 2024
63db4d5
various recommendations
dsweber2 Sep 13, 2024
b5ed1b3
remaining recs besides metadata to term_info
dsweber2 Sep 13, 2024
50cecae
moving step checks to a separate function, styler
dsweber2 Sep 16, 2024
7f18662
NAomit can happen before (but probably shouldn't)
dsweber2 Sep 16, 2024
4b37f7e
draft implementation
dajmcdon Sep 16, 2024
9a7abc1
refactor the utility fun
dajmcdon Sep 16, 2024
82f1165
only do processing if locf
dajmcdon Sep 16, 2024
9044d5a
Update R/utils-latency.R
dajmcdon Sep 17, 2024
4a4a90c
fix needles/haystack bug
dajmcdon Sep 17, 2024
b129829
some formatting
dsweber2 Sep 17, 2024
6e5d6fc
fixed
dsweber2 Sep 17, 2024
388ccc1
single letter variables are impossible to search
dsweber2 Sep 17, 2024
54993d2
lat adj tag for steps which have been modified
dsweber2 Sep 17, 2024
e70d553
remove actual changes to prep.epi_recipe
dsweber2 Sep 20, 2024
b3c96b3
various requests and rebasing on dev
dsweber2 Sep 27, 2024
86c46a4
updating after rebase
dsweber2 Sep 30, 2024
90edb46
final requests
dsweber2 Sep 30, 2024
ce99138
only adding metadata if given an epi_df originally
dsweber2 Sep 30, 2024
c6800bb
snapshot updates
dsweber2 Sep 30, 2024
561570e
description and News
dsweber2 Oct 1, 2024
053b501
rerererebase
dsweber2 Oct 1, 2024
3 changes: 2 additions & 1 deletion .Rbuildignore
@@ -20,4 +20,5 @@
^doc$
^Meta$
^.lintr$
^.venv$
^.venv$
^inst/templates$
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.0
Version: 0.1.1
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -40,6 +40,7 @@ Imports:
magrittr,
recipes (>= 1.0.4),
rlang (>= 1.1.0),
purrr,
Review comment (Contributor):

We don't need it. See below.

Review comment (Contributor):

It looks like you only use `purrr::map()`. We already have this, so purrr only needs to be in 'Suggests'.

stats,
tibble,
tidyr,
2 changes: 2 additions & 0 deletions DEVELOPMENT.md
@@ -32,7 +32,9 @@ The `main` version is available at `file:///<local path>/epidatr/epipredict/inde
You can also build the docs manually and launch the site with python. From the terminal, this looks like

```bash
R -e 'pkgdown::clean_site()'
R -e 'devtools::document()'
R -e 'pkgdown::build_site()'
python -m http.server -d docs
```

26 changes: 25 additions & 1 deletion NAMESPACE
@@ -19,6 +19,7 @@ S3method(autoplot,canned_epipred)
S3method(autoplot,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_adjust_latency)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
S3method(bake,step_epi_slide)
@@ -58,6 +59,7 @@ S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_adjust_latency)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
S3method(prep,step_epi_slide)
@@ -87,6 +89,7 @@ S3method(print,layer_quantile_distn)
S3method(print,layer_residual_quantiles)
S3method(print,layer_threshold)
S3method(print,layer_unnest)
S3method(print,step_adjust_latency)
S3method(print,step_epi_ahead)
S3method(print,step_epi_lag)
S3method(print,step_epi_slide)
@@ -195,6 +198,7 @@ export(remove_frosting)
export(remove_model)
export(slather)
export(smooth_quantile_reg)
export(step_adjust_latency)
export(step_epi_ahead)
export(step_epi_lag)
export(step_epi_naomit)
@@ -225,6 +229,7 @@ importFrom(checkmate,test_numeric)
importFrom(checkmate,test_scalar)
importFrom(cli,cli_abort)
importFrom(cli,cli_warn)
importFrom(dplyr,"%>%")
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,any_of)
@@ -235,13 +240,20 @@ importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_at)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,n)
importFrom(dplyr,pull)
importFrom(dplyr,relocate)
importFrom(dplyr,rename)
importFrom(dplyr,rowwise)
importFrom(dplyr,select)
importFrom(dplyr,summarise)
importFrom(dplyr,summarize)
importFrom(dplyr,tibble)
importFrom(dplyr,tribble)
importFrom(dplyr,ungroup)
importFrom(epiprocess,epi_slide)
importFrom(epiprocess,growth_rate)
@@ -255,18 +267,20 @@ importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_ribbon)
importFrom(glue,glue)
importFrom(hardhat,extract_recipe)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
importFrom(recipes,rand_id)
importFrom(rlang,"!!!")
importFrom(rlang,"!!")
importFrom(rlang,"%@%")
importFrom(rlang,"%||%")
importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_arg)
@@ -276,16 +290,19 @@ importFrom(rlang,enquos)
importFrom(rlang,expr)
importFrom(rlang,global_env)
importFrom(rlang,inject)
importFrom(rlang,is_empty)
importFrom(rlang,is_logical)
importFrom(rlang,is_null)
importFrom(rlang,is_true)
importFrom(rlang,list2)
importFrom(rlang,set_names)
importFrom(rlang,sym)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,na.omit)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,qnorm)
@@ -294,6 +311,12 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,crossing)
importFrom(tidyr,drop_na)
importFrom(tidyr,expand_grid)
importFrom(tidyr,fill)
importFrom(tidyr,unnest)
importFrom(tidyselect,all_of)
importFrom(utils,capture.output)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
@@ -303,3 +326,4 @@ importFrom(vctrs,vec_data)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_recycle_common)
importFrom(workflows,extract_preprocessor)
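
The new `S3method(prep,step_adjust_latency)`, `S3method(bake,step_adjust_latency)`, and `S3method(print,step_adjust_latency)` entries register class-specific methods for existing generics. As a rough illustration of that dispatch pattern, here is a minimal base R sketch. Every name in it (`prep_sketch`, `bake_sketch`, `toy_adjust_latency`) is hypothetical and stands in for the real `recipes` generics and the package's step class; it is not the PR's implementation.

```r
# Hypothetical stand-in generics; the real ones come from the recipes package.
prep_sketch <- function(x, ...) UseMethod("prep_sketch")
bake_sketch <- function(object, new_data, ...) UseMethod("bake_sketch")

# Constructor for a toy step object; its class drives method dispatch.
new_toy_step <- function() {
  structure(list(latency = NULL, trained = FALSE), class = "toy_adjust_latency")
}

# "prep" learns something from the training data (here: how stale the data is).
prep_sketch.toy_adjust_latency <- function(x, training, forecast_date, ...) {
  x$latency <- as.integer(forecast_date - max(training$time_value))
  x$trained <- TRUE
  x
}

# "bake" applies what was learned to new data.
bake_sketch.toy_adjust_latency <- function(object, new_data, ...) {
  new_data$latency <- object$latency
  new_data
}

# A print method for the same class, found through the base `print` generic.
print.toy_adjust_latency <- function(x, ...) {
  cat("Toy latency step, trained:", x$trained, "\n")
  invisible(x)
}

training <- data.frame(time_value = as.Date("2024-01-01") + 0:4, value = 1:5)
step <- prep_sketch(new_toy_step(), training, forecast_date = as.Date("2024-01-08"))
baked <- bake_sketch(step, training)
```

In a package, the `S3method()` directives in NAMESPACE are what make such methods visible to the generic without exporting the method functions themselves.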
7 changes: 7 additions & 0 deletions NEWS.md
@@ -2,6 +2,13 @@

Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PRs.

# epipredict 0.2

## features
- Add `step_adjust_latency`, which gives several methods to adjust the forecast if the `forecast_date` is after the last day of data.

## bugfixes

# epipredict 0.1

- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
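
To make the NEWS entry concrete, the situation it describes (a `forecast_date` later than the last day of data) can be sketched in base R. This is only an illustration of the idea suggested by the method name `extend_ahead`, with made-up dates; the variable names are assumptions and the package's actual implementation is not reproduced here.

```r
# Toy data: the most recent observation is three days before the forecast date.
time_value <- as.Date("2024-09-20") + 0:7 # last day of data is 2024-09-27
forecast_date <- as.Date("2024-09-30")
ahead <- 7L

# Latency: how far the forecast date sits past the last day of data.
latency <- as.integer(forecast_date - max(time_value))

# One way to compensate: predict further ahead, counting from the last
# observed day instead of from the forecast date.
effective_ahead <- ahead + latency

# The target date is the same whether counted from the forecast date
# or from the last observed day with the extended ahead.
target_from_forecast <- forecast_date + ahead
target_from_data <- max(time_value) + effective_ahead
```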
90 changes: 63 additions & 27 deletions R/arx_classifier.R
@@ -55,11 +55,18 @@ arx_classifier <- function(
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
wf <- fit(wf, epi_data)

if (args_list$adjust_latency == "none") {
forecast_date_default <- max(epi_data$time_value)
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date_default}.")
}
} else {
forecast_date_default <- attributes(epi_data)$metadata$as_of
}
forecast_date <- args_list$forecast_date %||% forecast_date_default
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
preds <- forecast(
wf,
fill_locf = TRUE,
n_recent = args_list$nafill_buffer,
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
) %>%
as_tibble() %>%
select(-time_value)
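
The default-selection logic in this hunk can be summarized as: use the user's `forecast_date` if given; otherwise fall back to the last `time_value` when no latency adjustment is requested, or to the data's `as_of` when one is. A minimal base R sketch of that rule follows. The helper `default_forecast_date()` and its arguments are hypothetical, and `%||%` is defined locally so the sketch runs without rlang.

```r
# Local null-coalescing operator so the sketch does not depend on rlang.
`%||%` <- function(x, y) if (is.null(x)) y else x

# Hypothetical helper mirroring the branch in the hunk above.
default_forecast_date <- function(time_values, as_of, adjust_latency,
                                  user_date = NULL) {
  fallback <- if (adjust_latency == "none") max(time_values) else as_of
  user_date %||% fallback
}

tv <- as.Date("2024-09-01") + 0:9 # last observed day is 2024-09-10
as_of <- as.Date("2024-09-13")

d_none <- default_forecast_date(tv, as_of, "none")
d_extend <- default_forecast_date(tv, as_of, "extend_ahead")
d_user <- default_forecast_date(tv, as_of, "none", as.Date("2024-09-12"))
```

Here `d_none` is the last observed day, `d_extend` is the `as_of` date, and `d_user` is whatever the caller supplied. The last case is the one the hunk warns about when the supplied date differs from the last observed day.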
@@ -125,27 +132,39 @@ arx_class_epi_workflow <- function(
if (!(is.null(trainer) || is_classification(trainer))) {
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
}

if (args_list$adjust_latency == "none") {
forecast_date_default <- max(epi_data$time_value)
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date_default}.")
}
} else {
forecast_date_default <- attributes(epi_data)$metadata$as_of
}
forecast_date <- args_list$forecast_date %||% forecast_date_default
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)

lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
# ------- predictors
r <- epi_recipe(epi_data) %>%
step_growth_rate(
dplyr::all_of(predictors),
all_of(predictors),
role = "grp",
horizon = args_list$horizon,
method = args_list$method,
log_scale = args_list$log_scale,
additional_gr_args_list = args_list$additional_gr_args
)
for (l in seq_along(lags)) {
p <- predictors[l]
p <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{p}"))
r <- step_epi_lag(r, !!p, lag = lags[[l]])
pred_names <- predictors[l]
pred_names <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{pred_names}"))
r <- step_epi_lag(r, !!pred_names, lag = lags[[l]])
}
# ------- outcome
if (args_list$outcome_transform == "lag_difference") {
o <- as.character(
pre_out_name <- as.character(
glue::glue_data(args_list, "lag_diff_{horizon}_{outcome}")
)
r <- r %>%
@@ -156,7 +175,7 @@ arx_class_epi_workflow <- function(
)
}
if (args_list$outcome_transform == "growth_rate") {
o <- as.character(
pre_out_name <- as.character(
glue::glue_data(args_list, "gr_{horizon}_{method}_{outcome}")
)
if (!(outcome %in% predictors)) {
@@ -171,11 +190,30 @@ arx_class_epi_workflow <- function(
)
}
}
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
# regex that will match any amount of adjustment for the ahead
ahead_out_name_regex <- glue::glue("ahead_[0-9]*_{pre_out_name}")
method_adjust_latency <- args_list$adjust_latency
if (method_adjust_latency != "none") {
if (method_adjust_latency != "extend_ahead") {
cli_abort("only extend_ahead is currently supported",
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
)
}
r <- r %>% step_adjust_latency(!!pre_out_name,
fixed_forecast_date = forecast_date,
method = method_adjust_latency
)
}
r <- r %>%
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
recipes::step_mutate(
outcome_class = cut(!!o2, breaks = args_list$breaks),
step_epi_ahead(!!pre_out_name, ahead = args_list$ahead, role = "pre-outcome")
r <- r %>%
step_mutate(
across(
matches(ahead_out_name_regex),
~ cut(.x, breaks = args_list$breaks),
.names = "outcome_class",
.unpack = TRUE
),
role = "outcome"
) %>%
step_epi_naomit() %>%
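
The hunk above replaces a hard-coded `ahead_<n>_` column name with a regex, so that the outcome column is still found after the ahead has been adjusted. A small base R sketch of that matching, followed by the binning step, is shown below. The column names and the break vector are made up for illustration, and `paste0()` stands in for `glue::glue()`.

```r
pre_out_name <- "gr_7_rel_change_case_rate" # hypothetical column stem
ahead_out_name_regex <- paste0("ahead_[0-9]*_", pre_out_name)

candidates <- c(
  "ahead_7_gr_7_rel_change_case_rate", # unadjusted ahead
  "ahead_10_gr_7_rel_change_case_rate", # ahead extended by a latency of 3
  "lag_7_gr_7_rel_change_case_rate" # a predictor column; should not match
)
matched <- grepl(ahead_out_name_regex, candidates)

# Bin a toy outcome into classes; these breaks are an assumption for
# illustration, not the values the package passes to cut().
outcome_class <- cut(c(-0.10, 0.10, 0.40), breaks = c(-Inf, 0.25, Inf))
class_codes <- as.integer(outcome_class)
```

Note that `[0-9]*` also accepts zero digits (for example `ahead__...`). A stricter pattern would use `[0-9]+`, though whether that matters depends on the column names the recipe can actually produce.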
@@ -192,10 +230,6 @@ arx_class_epi_workflow <- function(
)
}


forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)

# --- postprocessor
f <- frosting() %>% layer_predict() # %>% layer_naomit()
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
@@ -260,13 +294,14 @@ arx_class_args_list <- function(
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"),
warn_latency = TRUE,
outcome_transform = c("growth_rate", "lag_difference"),
breaks = 0.25,
horizon = 7L,
method = c("rel_change", "linear_reg"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
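
The new `adjust_latency = c("none", "extend_ahead", "extend_lags", "locf")` argument follows the usual "vector of choices as default" idiom, where the first element is the default and anything outside the vector is rejected. A base R sketch of the idiom is below. It uses `match.arg()` as a stand-in for `rlang::arg_match()` (the call the PR actually uses), and `pick_adjust_latency()` is a hypothetical function.

```r
# Hypothetical function; match.arg() plays the role of rlang::arg_match().
pick_adjust_latency <- function(
    adjust_latency = c("none", "extend_ahead", "extend_lags", "locf")) {
  match.arg(adjust_latency)
}

# With no argument supplied, the first choice is selected.
default_choice <- pick_adjust_latency()

# An explicit valid choice is passed through unchanged.
explicit_choice <- pick_adjust_latency("extend_ahead")

# Anything outside the choice vector raises an error.
invalid_choice <- tryCatch(
  pick_adjust_latency("nearest"),
  error = function(e) "error"
)
```

One difference worth keeping in mind: `match.arg()` allows partial matching of choices, while `rlang::arg_match()` requires an exact match.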
@@ -276,15 +311,15 @@ arx_class_args_list <- function(
method <- rlang::arg_match(method)
outcome_transform <- rlang::arg_match(outcome_transform)

arg_is_scalar(ahead, n_training, horizon, log_scale)
adjust_latency <- rlang::arg_match(adjust_latency)
arg_is_scalar(ahead, n_training, horizon, log_scale, adjust_latency, warn_latency)
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
arg_is_date(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, lags, horizon)
arg_is_numeric(breaks)
arg_is_lgl(log_scale)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
if (!is.list(additional_gr_args)) {
cli_abort(c(
"`additional_gr_args` must be a {.cls list}.",
@@ -297,10 +332,13 @@ arx_class_args_list <- function(

if (!is.null(forecast_date) && !is.null(target_date)) {
if (forecast_date + ahead != target_date) {
cli::cli_warn(c(
"`forecast_date` + `ahead` must equal `target_date`.",
i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}."
))
cli_warn(
paste0(
"`forecast_date` {.val {forecast_date}} +",
" `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}."
),
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
)
}
}
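
The check above warns when `forecast_date + ahead` differs from `target_date`, and the new `class =` argument lets tests and callers catch that specific warning. The following base R sketch shows the same pattern with `warningCondition()` in place of `cli_warn()`. The helper names and the class name `inconsistent_dates_sketch` are hypothetical.

```r
# Hypothetical helper: TRUE when the three arguments agree (or cannot be compared).
dates_are_consistent <- function(forecast_date, ahead, target_date) {
  if (is.null(forecast_date) || is.null(target_date)) {
    return(TRUE)
  }
  forecast_date + ahead == target_date
}

# Signal a classed warning so callers can handle this case specifically.
check_dates <- function(forecast_date, ahead, target_date) {
  if (!dates_are_consistent(forecast_date, ahead, target_date)) {
    warning(warningCondition(
      sprintf(
        "`forecast_date` %s + `ahead` %d must equal `target_date` %s.",
        format(forecast_date), ahead, format(target_date)
      ),
      class = "inconsistent_dates_sketch"
    ))
  }
  invisible(TRUE)
}

# Consistent inputs pass silently.
ok_result <- check_dates(as.Date("2024-09-30"), 7L, as.Date("2024-10-07"))

# Inconsistent inputs raise a warning that can be caught by its class.
caught <- tryCatch(
  check_dates(as.Date("2024-09-30"), 7L, as.Date("2024-10-08")),
  inconsistent_dates_sketch = function(w) "caught"
)
```

Catching by class rather than by message text keeps tests stable when the wording of the warning changes.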

@@ -318,13 +356,13 @@ arx_class_args_list <- function(
breaks,
forecast_date,
target_date,
adjust_latency,
outcome_transform,
max_lags,
horizon,
method,
log_scale,
additional_gr_args,
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
@@ -337,5 +375,3 @@ print.arx_class <- function(x, ...) {
name <- "ARX Classifier"
NextMethod(name = name, ...)
}

# this is a trivial change to induce a check