23
23
from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
24
24
from botorch .exceptions import InputDataError , UnsupportedError
25
25
from botorch .exceptions .warnings import OptimizationWarning
26
- from botorch .generation .gen import gen_candidates_scipy
26
+ from botorch .generation .gen import gen_candidates_scipy , TGenCandidates
27
27
from botorch .logging import logger
28
28
from botorch .optim .initializers import (
29
29
gen_batch_initial_conditions ,
30
30
gen_one_shot_kg_initial_conditions ,
31
31
)
32
32
from botorch .optim .stopping import ExpMAStoppingCriterion
33
+ from botorch .optim .utils import _filter_kwargs
33
34
from torch import Tensor
34
35
35
36
INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
64
65
post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
65
66
batch_initial_conditions : Optional [Tensor ] = None ,
66
67
return_best_only : bool = True ,
68
+ gen_candidates : TGenCandidates = gen_candidates_scipy ,
67
69
sequential : bool = False ,
68
70
** kwargs : Any ,
69
71
) -> Tuple [Tensor , Tensor ]:
@@ -103,6 +105,12 @@ def optimize_acqf(
103
105
this if you do not want to use default initialization strategy.
104
106
return_best_only: If False, outputs the solutions corresponding to all
105
107
random restart initializations of the optimization.
108
+ gen_candidates: A callable for generating candidates (and their associated
109
+ acquisition values) given a tensor of initial conditions and an
110
+ acquisition function. Other common inputs include lower and upper bounds
111
+ and a dictionary of options, but refer to the documentation of specific
112
+ generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
113
+ for method-specific inputs. Default: `gen_candidates_scipy`
106
114
sequential: If False, uses joint optimization, otherwise uses sequential
107
115
optimization.
108
116
kwargs: Additonal keyword arguments.
@@ -214,6 +222,7 @@ def optimize_acqf(
214
222
sequential = False ,
215
223
ic_generator = ic_gen ,
216
224
timeout_sec = timeout_sec ,
225
+ gen_candidates = gen_candidates ,
217
226
)
218
227
219
228
candidate_list .append (candidate )
@@ -262,6 +271,11 @@ def optimize_acqf(
262
271
batch_limit : int = options .get (
263
272
"batch_limit" , num_restarts if not nonlinear_inequality_constraints else 1
264
273
)
274
+ has_parameter_constraints = (
275
+ inequality_constraints is not None
276
+ or equality_constraints is not None
277
+ or nonlinear_inequality_constraints is not None
278
+ )
265
279
266
280
def _optimize_batch_candidates (
267
281
timeout_sec : Optional [float ],
@@ -273,24 +287,33 @@ def _optimize_batch_candidates(
273
287
if timeout_sec is not None :
274
288
timeout_sec = (timeout_sec - start_time ) / len (batched_ics )
275
289
276
- scipy_kws = {
290
+ gen_kwargs = {
277
291
"acquisition_function" : acq_function ,
278
292
"lower_bounds" : None if bounds [0 ].isinf ().all () else bounds [0 ],
279
293
"upper_bounds" : None if bounds [1 ].isinf ().all () else bounds [1 ],
280
294
"options" : {k : v for k , v in options .items () if k not in INIT_OPTION_KEYS },
281
- "inequality_constraints" : inequality_constraints ,
282
- "equality_constraints" : equality_constraints ,
283
- "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
284
295
"fixed_features" : fixed_features ,
285
296
"timeout_sec" : timeout_sec ,
286
297
}
287
298
299
+ if has_parameter_constraints :
300
+ # only add parameter constraints to gen_kwargs if they are specified
301
+ # to avoid unnecessary warnings in _filter_kwargs
302
+ gen_kwargs .update (
303
+ {
304
+ "inequality_constraints" : inequality_constraints ,
305
+ "equality_constraints" : equality_constraints ,
306
+ "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
307
+ }
308
+ )
309
+ filtered_gen_kwargs = _filter_kwargs (gen_candidates , ** gen_kwargs )
310
+
288
311
for i , batched_ics_ in enumerate (batched_ics ):
289
312
# optimize using random restart optimization
290
313
with warnings .catch_warnings (record = True ) as ws :
291
314
warnings .simplefilter ("always" , category = OptimizationWarning )
292
- batch_candidates_curr , batch_acq_values_curr = gen_candidates_scipy (
293
- initial_conditions = batched_ics_ , ** scipy_kws
315
+ batch_candidates_curr , batch_acq_values_curr = gen_candidates (
316
+ initial_conditions = batched_ics_ , ** filtered_gen_kwargs
294
317
)
295
318
opt_warnings += ws
296
319
batch_candidates_list .append (batch_candidates_curr )
0 commit comments