Skip to content

Commit adc9d77

Browse files
committed
xgboost mtry parameter swap for #495
1 parent adf0f32 commit adc9d77

File tree

7 files changed

+38
-14
lines changed

7 files changed

+38
-14
lines changed

NEWS.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111

1212
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
1313

14-
* New model specification `survival_reg()` for the new mode `"censored regression"` (#444). `surv_reg()` is now soft-deprecated (#448).
15-
16-
* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451).
14+
* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. (#495)
1715

1816
## Other Changes
1917

R/boost_tree.R

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ check_args.boost_tree <- function(object) {
264264
#' @param max_depth An integer for the maximum depth of the tree.
265265
#' @param nrounds An integer for the number of boosting iterations.
266266
#' @param eta A numeric value between zero and one to control the learning rate.
267-
#' @param colsample_bytree Subsampling proportion of columns.
267+
#' @param colsample_bytree Subsampling proportion of columns for each tree.
268+
#' @param colsample_bynode Subsampling proportion of columns for each node
269+
#' within each tree.
268270
#' @param min_child_weight A numeric value for the minimum sum of instance
269271
#' weights needed in a child to continue to split.
270272
#' @param gamma A number for the minimum loss reduction required to make a
@@ -290,8 +292,8 @@ check_args.boost_tree <- function(object) {
290292
#' @export
291293
xgb_train <- function(
292294
x, y,
293-
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
294-
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
295+
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = 1,
296+
colsample_bytree = 1, min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
295297
early_stop = NULL, objective = NULL,
296298
event_level = c("first", "second"),
297299
...) {
@@ -346,6 +348,13 @@ xgb_train <- function(
346348
colsample_bytree <- 1
347349
}
348350

351+
if (colsample_bynode > 1) {
352+
colsample_bynode <- colsample_bynode/p
353+
}
354+
if (colsample_bynode > 1) {
355+
colsample_bynode <- 1
356+
}
357+
349358
if (min_child_weight > n) {
350359
msg <- paste0(min_child_weight, " samples were requested but there were ",
351360
n, " rows in the data. ", n, " will be used.")
@@ -358,6 +367,7 @@ xgb_train <- function(
358367
max_depth = max_depth,
359368
gamma = gamma,
360369
colsample_bytree = colsample_bytree,
370+
colsample_bynode = colsample_bynode,
361371
min_child_weight = min(min_child_weight, n),
362372
subsample = subsample,
363373
objective = objective

R/boost_tree_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ set_model_arg(
3737
model = "boost_tree",
3838
eng = "xgboost",
3939
parsnip = "mtry",
40-
original = "colsample_bytree",
40+
original = "colsample_bynode",
4141
func = list(pkg = "dials", fun = "mtry"),
4242
has_submodel = FALSE
4343
)

man/boost_tree.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/boost-tree.Rmd

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ mod_param <-
3838
update(sample_size = sample_prop(c(0.4, 0.9)))
3939
```
4040

41-
For this engine, tuning over `trees` is very efficient since the same model
42-
object can be used to make predictions over multiple values of `trees`.
41+
For this engine, tuning over `trees` is very efficient since the same model object can be used to make predictions over multiple values of `trees`.
4342

4443
Note that `xgboost` models require that non-numeric predictors (e.g., factors) must be converted to dummy variables or some other numeric representation. By default, when using `fit()` with `xgboost`, a one-hot encoding is used to convert factor predictors to indicator variables.
4544

@@ -89,7 +88,7 @@ get_defaults_boost_tree <- function() {
8988
"boost_tree", "xgboost", "tree_depth", "max_depth", get_arg("parsnip", "xgb_train", "max_depth"),
9089
"boost_tree", "xgboost", "trees", "nrounds", get_arg("parsnip", "xgb_train", "nrounds"),
9190
"boost_tree", "xgboost", "learn_rate", "eta", get_arg("parsnip", "xgb_train", "eta"),
92-
"boost_tree", "xgboost", "mtry", "colsample_bytree", get_arg("parsnip", "xgb_train", "colsample_bytree"),
91+
"boost_tree", "xgboost", "mtry", "colsample_bynode", get_arg("parsnip", "xgb_train", "colsample_bynode"),
9392
"boost_tree", "xgboost", "min_n", "min_child_weight", get_arg("parsnip", "xgb_train", "min_child_weight"),
9493
"boost_tree", "xgboost", "loss_reduction", "gamma", get_arg("parsnip", "xgb_train", "gamma"),
9594
"boost_tree", "xgboost", "sample_size", "subsample", get_arg("parsnip", "xgb_train", "subsample"),

man/xgb_train.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,9 @@ test_that('argument checks for data dimensions', {
414414
xy_fit <- spec %>% fit_xy(x = penguins_dummy, y = penguins$species),
415415
"1000 samples were requested"
416416
)
417-
expect_equal(f_fit$fit$params$colsample_bytree, 1)
417+
expect_equal(f_fit$fit$params$colsample_bynode, 1)
418418
expect_equal(f_fit$fit$params$min_child_weight, nrow(penguins))
419-
expect_equal(xy_fit$fit$params$colsample_bytree, 1)
419+
expect_equal(xy_fit$fit$params$colsample_bynode, 1)
420420
expect_equal(xy_fit$fit$params$min_child_weight, nrow(penguins))
421421

422422
})
@@ -482,3 +482,16 @@ test_that("fit and prediction with `event_level`", {
482482
expect_equal(pred_p_2[[".pred_male"]], pred_xgb_2)
483483

484484
})
485+
486+
test_that("mtry parameters", {
487+
skip_if_not_installed("xgboost")
488+
fit <-
489+
boost_tree(mtry = .7, trees = 4) %>%
490+
set_engine("xgboost") %>%
491+
set_mode("regression") %>%
492+
fit(mpg ~ ., data = mtcars)
493+
expect_equal(fit$fit$params$colsample_bytree, 1)
494+
expect_equal(fit$fit$params$colsample_bynode, 0.7)
495+
})
496+
497+

0 commit comments

Comments
 (0)