23
23
from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
24
24
from botorch .exceptions import InputDataError , UnsupportedError
25
25
from botorch .exceptions .warnings import OptimizationWarning
26
- from botorch .generation .gen import gen_candidates_scipy
26
+ from botorch .generation .gen import gen_candidates_scipy , TGenCandidates
27
27
from botorch .logging import logger
28
28
from botorch .optim .initializers import (
29
29
gen_batch_initial_conditions ,
@@ -64,6 +64,7 @@ def optimize_acqf(
64
64
post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
65
65
batch_initial_conditions : Optional [Tensor ] = None ,
66
66
return_best_only : bool = True ,
67
+ gen_candidates : TGenCandidates = gen_candidates_scipy ,
67
68
sequential : bool = False ,
68
69
** kwargs : Any ,
69
70
) -> Tuple [Tensor , Tensor ]:
@@ -103,6 +104,8 @@ def optimize_acqf(
103
104
this if you do not want to use default initialization strategy.
104
105
return_best_only: If False, outputs the solutions corresponding to all
105
106
random restart initializations of the optimization.
107
+ gen_candidates: A callable for generating candidates given initial
108
+ conditions. Default: `gen_candidates_scipy`
106
109
sequential: If False, uses joint optimization, otherwise uses sequential
107
110
optimization.
108
111
kwargs: Additonal keyword arguments.
@@ -273,7 +276,7 @@ def _optimize_batch_candidates(
273
276
if timeout_sec is not None :
274
277
timeout_sec = (timeout_sec - start_time ) / len (batched_ics )
275
278
276
- scipy_kws = {
279
+ gen_kws = {
277
280
"acquisition_function" : acq_function ,
278
281
"lower_bounds" : None if bounds [0 ].isinf ().all () else bounds [0 ],
279
282
"upper_bounds" : None if bounds [1 ].isinf ().all () else bounds [1 ],
@@ -289,8 +292,8 @@ def _optimize_batch_candidates(
289
292
# optimize using random restart optimization
290
293
with warnings .catch_warnings (record = True ) as ws :
291
294
warnings .simplefilter ("always" , category = OptimizationWarning )
292
- batch_candidates_curr , batch_acq_values_curr = gen_candidates_scipy (
293
- initial_conditions = batched_ics_ , ** scipy_kws
295
+ batch_candidates_curr , batch_acq_values_curr = gen_candidates (
296
+ initial_conditions = batched_ics_ , ** gen_kwargs
294
297
)
295
298
opt_warnings += ws
296
299
batch_candidates_list .append (batch_candidates_curr )
0 commit comments