Skip to content

Commit bec0423

Browse files
author
‘topepo’
committed
enable parsnip to work with functions wit parameterized labels
1 parent c0d917e commit bec0423

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

R/extract.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ extract_parameter_set_dials.model_spec <- function(x, ...) {
113113
eval_call_info <- function(x) {
114114
if (!is.null(x)) {
115115
# Look for other options
116-
allowed_opts <- c("range", "trans", "values")
116+
allowed_opts <- c("range", "trans", "values", "label")
117117
if (any(names(x) %in% allowed_opts)) {
118118
opts <- x[names(x) %in% allowed_opts]
119119
} else {

R/tunable.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ brulee_mlp_engine_args <-
199199
~name, ~call_info,
200200
"momentum", list(pkg = "dials", fun = "momentum", range = c(0.5, 0.95)),
201201
"batch_size", list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
202-
"hidden_units_2", list(pkg = "dials", fun = "hidden_units"),
203-
"activation_2", list(pkg = "dials", fun = "activation"),
202+
"hidden_units_2", list(pkg = "dials", fun = "hidden_units", label = "# Hidden Units (layer 2)"),
203+
"activation_2", list(pkg = "dials", fun = "activation", label = "Activation Function (layer 2)"),
204204
"stop_iter", list(pkg = "dials", fun = "stop_iter"),
205205
"class_weights", list(pkg = "dials", fun = "class_weights"),
206206
"decay", list(pkg = "dials", fun = "rate_decay"),
@@ -364,6 +364,7 @@ tunable.mlp <- function(x, ...) {
364364
res$call_info[res$name == "epochs"] <-
365365
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
366366
} else if (x$engine == "brulee_two_layer") {
367+
rlang::check_installed("brulee", version = "0.3.0.9000")
367368
res <- add_engine_parameters(res, brulee_mlp_engine_args)
368369
res$call_info[res$name == "learn_rate"] <-
369370
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))

0 commit comments

Comments
 (0)