Commit 46a2018

Merge pull request #499 from tidymodels/xgb-mtry
xgboost mtry parameter swap for #495
2 parents adf0f32 + 374dbf9 commit 46a2018

7 files changed: +128 −44 lines changed

NEWS.md

Lines changed: 2 additions & 6 deletions
@@ -11,9 +11,9 @@

 * The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)

-* New model specification `survival_reg()` for the new mode `"censored regression"` (#444). `surv_reg()` is now soft-deprecated (#448).
+* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. `colsample_bynode` was added to xgboost after the `parsnip` package code was written. (#495)

-* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451).
+* For xgboost boosting, `mtry` and `colsample_bytree` can be passed as integer counts or as proportions, while `subsample` and `validation` should be proportions. `xgb_train()` now has a new option, `counts`, for stating which scale `mtry` and `colsample_bytree` are on. (#461)

 ## Other Changes

@@ -23,12 +23,8 @@

 * Re-organized model documentation for `update` methods (#479).

-
-
 * `generics::required_pkgs()` was extended for `parsnip` objects.

-
-
 # parsnip 0.1.5

 * An RStudio add-in is available that writes multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.
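A quick usage sketch of the two NEWS items above (hedged: it mirrors the new tests added later in this commit and assumes this version of parsnip plus xgboost are installed):

```r
library(parsnip)

# `mtry` now maps to xgboost's `colsample_bynode`; `colsample_bytree` can
# still be set (or tuned) as an engine argument. Both are counts here.
boost_tree(mtry = 7, trees = 4) %>%
  set_engine("xgboost", colsample_bytree = 4) %>%
  set_mode("regression") %>%
  fit(mpg ~ ., data = mtcars)

# With `counts = FALSE`, both arguments are read as proportions instead.
boost_tree(mtry = 0.9, trees = 4) %>%
  set_engine("xgboost", colsample_bytree = 0.1, counts = FALSE) %>%
  set_mode("regression") %>%
  fit(mpg ~ ., data = mtcars)
```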

R/boost_tree.R

Lines changed: 56 additions & 23 deletions
@@ -264,20 +264,26 @@ check_args.boost_tree <- function(object) {
 #' @param max_depth An integer for the maximum depth of the tree.
 #' @param nrounds An integer for the number of boosting iterations.
 #' @param eta A numeric value between zero and one to control the learning rate.
-#' @param colsample_bytree Subsampling proportion of columns.
+#' @param colsample_bytree Subsampling proportion of columns for each tree.
+#' See the `counts` argument below. The default uses all columns.
+#' @param colsample_bynode Subsampling proportion of columns for each node
+#' within each tree. See the `counts` argument below. The default uses all
+#' columns.
 #' @param min_child_weight A numeric value for the minimum sum of instance
 #' weights needed in a child to continue to split.
 #' @param gamma A number for the minimum loss reduction required to make a
 #' further partition on a leaf node of the tree.
-#' @param subsample Subsampling proportion of rows.
-#' @param validation A positive number. If on `[0, 1)`, `validation` is a
-#' random proportion of the data in `x` and `y` that are used for performance
-#' assessment and potential early stopping. If 1 or greater, it is the _number_
-#' of training set samples used for these purposes.
+#' @param subsample Subsampling proportion of rows. By default, all of the
+#' training data are used.
+#' @param validation The _proportion_ of the data that are used for performance
+#' assessment and potential early stopping.
 #' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
 #' training iterations without improvement before stopping. If `validation` is
 #' used, performance is based on the validation set; otherwise, the training set
 #' is used.
+#' @param counts A logical. If `FALSE`, `colsample_bynode` and
+#' `colsample_bytree` are both assumed to be _proportions_ of the number of
+#' columns (instead of integer counts).
 #' @param objective A single string (or NULL) that defines the loss function that
 #' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
 #' NULL, an appropriate loss function is chosen.
@@ -290,11 +296,10 @@ check_args.boost_tree <- function(object) {
 #' @export
 xgb_train <- function(
   x, y,
-  max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
-  min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
-  early_stop = NULL, objective = NULL,
-  event_level = c("first", "second"),
-  ...) {
+  max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
+  colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
+  validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
+  event_level = c("first", "second"), ...) {

   event_level <- rlang::arg_match(event_level, c("first", "second"))
   others <- list(...)
@@ -304,6 +309,7 @@ xgb_train <- function(
   if (!is.numeric(validation) || validation < 0 || validation >= 1) {
     rlang::abort("`validation` should be on [0, 1).")
   }
+
   if (!is.null(early_stop)) {
     if (early_stop <= 1) {
       rlang::abort(paste0("`early_stop` should be on [2, ", nrounds, ")."))
@@ -313,7 +319,6 @@ xgb_train <- function(
     }
   }

-
   if (is.null(objective)) {
     if (is.numeric(y)) {
       objective <- "reg:squarederror"
@@ -331,19 +336,21 @@ xgb_train <- function(

   x <- as_xgb_data(x, y, validation, event_level)

-  # translate `subsample` and `colsample_bytree` to be on (0, 1] if not
-  if (subsample > 1) {
-    subsample <- subsample/n
-  }
-  if (subsample > 1) {
-    subsample <- 1
-  }

-  if (colsample_bytree > 1) {
-    colsample_bytree <- colsample_bytree/p
+  if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
+    rlang::abort("`subsample` should be on [0, 1].")
   }
-  if (colsample_bytree > 1) {
+
+  # initialize
+  if (is.null(colsample_bytree)) {
     colsample_bytree <- 1
+  } else {
+    colsample_bytree <- recalc_param(colsample_bytree, counts, p)
+  }
+  if (is.null(colsample_bynode)) {
+    colsample_bynode <- 1
+  } else {
+    colsample_bynode <- recalc_param(colsample_bynode, counts, p)
   }

   if (min_child_weight > n) {
@@ -358,6 +365,7 @@ xgb_train <- function(
     max_depth = max_depth,
     gamma = gamma,
     colsample_bytree = colsample_bytree,
+    colsample_bynode = colsample_bynode,
     min_child_weight = min(min_child_weight, n),
     subsample = subsample,
     objective = objective
@@ -390,6 +398,30 @@ xgb_train <- function(
   eval_tidy(call, env = current_env())
 }

+recalc_param <- function(x, counts, denom) {
+  nm <- as.character(match.call()$x)
+  if (is.null(x)) {
+    x <- 1
+  } else {
+    if (counts) {
+      maybe_proportion(x, nm)
+      x <- min(denom, x)/denom
+    }
+  }
+  x
+}
+
+maybe_proportion <- function(x, nm) {
+  if (x < 1) {
+    msg <- paste0(
+      "The option `counts = TRUE` was used but parameter `", nm,
+      "` was given as ", signif(x, 3), ". Please use a value >= 1 or use ",
+      "`counts = FALSE`."
+    )
+    rlang::abort(msg)
+  }
+}
+
 #' @importFrom stats binomial
 xgb_pred <- function(object, newdata, ...) {
   if (!inherits(newdata, "xgb.DMatrix")) {
@@ -432,7 +464,8 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {

   if (!inherits(x, "xgb.DMatrix")) {
     if (validation > 0) {
-      trn_index <- sample(1:n, size = floor(n * (1 - validation)) + 1)
+      m <- floor(n * (1 - validation)) + 1
+      trn_index <- sample(1:n, size = max(m, 2))
       wlist <-
         list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
       dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
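As a quick illustration of what the new internal helpers above do (a sketch with assumed values, not part of the commit): with `counts = TRUE`, a column count is capped at the number of predictors and divided by it, while a value below 1 aborts via `maybe_proportion()`.

```r
# Assuming p = 10 predictors and mtry = 7 given as a count:
p <- 10
mtry <- 7
min(p, mtry) / p
#> [1] 0.7   # the proportion recalc_param() passes on as colsample_bynode

# mtry = 0.9 with counts = TRUE would error, since a value < 1 looks like
# a proportion rather than a count; use counts = FALSE to pass proportions.
```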

R/boost_tree_data.R

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ set_model_arg(
   model = "boost_tree",
   eng = "xgboost",
   parsnip = "mtry",
-  original = "colsample_bytree",
+  original = "colsample_bynode",
   func = list(pkg = "dials", fun = "mtry"),
   has_submodel = FALSE
 )
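To see the effect of this re-registration (an illustrative sketch, not part of the commit), `translate()` shows where `mtry` now lands in the generated `xgb_train()` call:

```r
library(parsnip)

# With the updated registration, the printed model template maps `mtry`
# to the `colsample_bynode` argument of parsnip::xgb_train().
boost_tree(mtry = 7, trees = 4) %>%
  set_engine("xgboost") %>%
  set_mode("regression") %>%
  translate()
```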

man/boost_tree.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

man/rmd/boost-tree.Rmd

Lines changed: 2 additions & 3 deletions
@@ -38,8 +38,7 @@ mod_param <-
   update(sample_size = sample_prop(c(0.4, 0.9)))
 ```

-For this engine, tuning over `trees` is very efficient since the same model
-object can be used to make predictions over multiple values of `trees`.
+For this engine, tuning over `trees` is very efficient since the same model object can be used to make predictions over multiple values of `trees`.

 Note that `xgboost` models require that non-numeric predictors (e.g., factors) be converted to dummy variables or some other numeric representation. By default, when using `fit()` with `xgboost`, a one-hot encoding is used to convert factor predictors to indicator variables.

@@ -89,7 +88,7 @@ get_defaults_boost_tree <- function() {
   "boost_tree", "xgboost", "tree_depth", "max_depth", get_arg("parsnip", "xgb_train", "max_depth"),
   "boost_tree", "xgboost", "trees", "nrounds", get_arg("parsnip", "xgb_train", "nrounds"),
   "boost_tree", "xgboost", "learn_rate", "eta", get_arg("parsnip", "xgb_train", "eta"),
-  "boost_tree", "xgboost", "mtry", "colsample_bytree", get_arg("parsnip", "xgb_train", "colsample_bytree"),
+  "boost_tree", "xgboost", "mtry", "colsample_bynode", get_arg("parsnip", "xgb_train", "colsample_bynode"),
   "boost_tree", "xgboost", "min_n", "min_child_weight", get_arg("parsnip", "xgb_train", "min_child_weight"),
   "boost_tree", "xgboost", "loss_reduction", "gamma", get_arg("parsnip", "xgb_train", "gamma"),
   "boost_tree", "xgboost", "sample_size", "subsample", get_arg("parsnip", "xgb_train", "subsample"),

man/xgb_train.Rd

Lines changed: 17 additions & 7 deletions
Some generated files are not rendered by default.

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 48 additions & 2 deletions
@@ -414,9 +414,9 @@ test_that('argument checks for data dimensions', {
     xy_fit <- spec %>% fit_xy(x = penguins_dummy, y = penguins$species),
     "1000 samples were requested"
   )
-  expect_equal(f_fit$fit$params$colsample_bytree, 1)
+  expect_equal(f_fit$fit$params$colsample_bynode, 1)
   expect_equal(f_fit$fit$params$min_child_weight, nrow(penguins))
-  expect_equal(xy_fit$fit$params$colsample_bytree, 1)
+  expect_equal(xy_fit$fit$params$colsample_bynode, 1)
   expect_equal(xy_fit$fit$params$min_child_weight, nrow(penguins))

 })

@@ -482,3 +482,49 @@ test_that("fit and prediction with `event_level`", {
   expect_equal(pred_p_2[[".pred_male"]], pred_xgb_2)

 })
+
+test_that("count/proportion parameters", {
+  skip_if_not_installed("xgboost")
+  fit1 <-
+    boost_tree(mtry = 7, trees = 4) %>%
+    set_engine("xgboost") %>%
+    set_mode("regression") %>%
+    fit(mpg ~ ., data = mtcars)
+  expect_equal(fit1$fit$params$colsample_bytree, 1)
+  expect_equal(fit1$fit$params$colsample_bynode, 7/(ncol(mtcars) - 1))
+
+  fit2 <-
+    boost_tree(mtry = 7, trees = 4) %>%
+    set_engine("xgboost", colsample_bytree = 4) %>%
+    set_mode("regression") %>%
+    fit(mpg ~ ., data = mtcars)
+  expect_equal(fit2$fit$params$colsample_bytree, 4/(ncol(mtcars) - 1))
+  expect_equal(fit2$fit$params$colsample_bynode, 7/(ncol(mtcars) - 1))
+
+  fit3 <-
+    boost_tree(trees = 4) %>%
+    set_engine("xgboost") %>%
+    set_mode("regression") %>%
+    fit(mpg ~ ., data = mtcars)
+  expect_equal(fit3$fit$params$colsample_bytree, 1)
+  expect_equal(fit3$fit$params$colsample_bynode, 1)
+
+  fit4 <-
+    boost_tree(mtry = .9, trees = 4) %>%
+    set_engine("xgboost", colsample_bytree = .1, counts = FALSE) %>%
+    set_mode("regression") %>%
+    fit(mpg ~ ., data = mtcars)
+  expect_equal(fit4$fit$params$colsample_bytree, .1)
+  expect_equal(fit4$fit$params$colsample_bynode, .9)
+
+  expect_error(
+    boost_tree(mtry = .9, trees = 4) %>%
+      set_engine("xgboost") %>%
+      set_mode("regression") %>%
+      fit(mpg ~ ., data = mtcars),
+    "was given as 0.9"
+  )
+
+})
