@@ -199,8 +199,8 @@ brulee_mlp_engine_args <-
199
199
~ name , ~ call_info ,
200
200
" momentum" , list (pkg = " dials" , fun = " momentum" , range = c(0.5 , 0.95 )),
201
201
" batch_size" , list (pkg = " dials" , fun = " batch_size" , range = c(3 , 10 )),
202
- " hidden_units_2" , list (pkg = " dials" , fun = " hidden_units" ),
203
- " activation_2" , list (pkg = " dials" , fun = " activation" ),
202
+ " hidden_units_2" , list (pkg = " dials" , fun = " hidden_units" , label = " # Hidden Units (layer 2) " ),
203
+ " activation_2" , list (pkg = " dials" , fun = " activation" , label = " Activation Function (layer 2) " ),
204
204
" stop_iter" , list (pkg = " dials" , fun = " stop_iter" ),
205
205
" class_weights" , list (pkg = " dials" , fun = " class_weights" ),
206
206
" decay" , list (pkg = " dials" , fun = " rate_decay" ),
@@ -364,6 +364,7 @@ tunable.mlp <- function(x, ...) {
364
364
res $ call_info [res $ name == " epochs" ] <-
365
365
list (list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )))
366
366
} else if (x $ engine == " brulee_two_layer" ) {
367
+ rlang :: check_installed(" brulee" , version = " 0.3.0.9000" )
367
368
res <- add_engine_parameters(res , brulee_mlp_engine_args )
368
369
res $ call_info [res $ name == " learn_rate" ] <-
369
370
list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
0 commit comments