23
23
from botorch .acquisition .knowledge_gradient import qKnowledgeGradient
24
24
from botorch .exceptions import InputDataError , UnsupportedError
25
25
from botorch .exceptions .warnings import OptimizationWarning
26
- from botorch .generation .gen import gen_candidates_scipy
26
+ from botorch .generation .gen import gen_candidates_scipy , TGenCandidates
27
27
from botorch .logging import logger
28
28
from botorch .optim .initializers import (
29
29
gen_batch_initial_conditions ,
30
30
gen_one_shot_kg_initial_conditions ,
31
31
)
32
32
from botorch .optim .stopping import ExpMAStoppingCriterion
33
+ from botorch .optim .utils import _filter_kwargs
33
34
from torch import Tensor
34
35
35
36
INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
64
65
post_processing_func : Optional [Callable [[Tensor ], Tensor ]] = None ,
65
66
batch_initial_conditions : Optional [Tensor ] = None ,
66
67
return_best_only : bool = True ,
68
+ gen_candidates : Optional [TGenCandidates ] = None ,
67
69
sequential : bool = False ,
68
70
** kwargs : Any ,
69
71
) -> Tuple [Tensor , Tensor ]:
@@ -103,6 +105,12 @@ def optimize_acqf(
103
105
this if you do not want to use default initialization strategy.
104
106
return_best_only: If False, outputs the solutions corresponding to all
105
107
random restart initializations of the optimization.
108
+ gen_candidates: A callable for generating candidates (and their associated
109
+ acquisition values) given a tensor of initial conditions and an
110
+ acquisition function. Other common inputs include lower and upper bounds
111
+ and a dictionary of options, but refer to the documentation of specific
112
+ generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
113
+ for method-specific inputs. Default: `gen_candidates_scipy`
106
114
sequential: If False, uses joint optimization, otherwise uses sequential
107
115
optimization.
108
116
kwargs: Additonal keyword arguments.
@@ -134,6 +142,9 @@ def optimize_acqf(
134
142
"""
135
143
start_time : float = time .monotonic ()
136
144
timeout_sec = kwargs .pop ("timeout_sec" , None )
145
+ # using a default of None simplifies unit testing
146
+ if gen_candidates is None :
147
+ gen_candidates = gen_candidates_scipy
137
148
138
149
if inequality_constraints is None :
139
150
if not (bounds .ndim == 2 and bounds .shape [0 ] == 2 ):
@@ -229,6 +240,7 @@ def optimize_acqf(
229
240
sequential = False ,
230
241
ic_generator = ic_gen ,
231
242
timeout_sec = timeout_sec ,
243
+ gen_candidates = gen_candidates ,
232
244
)
233
245
234
246
candidate_list .append (candidate )
@@ -277,6 +289,11 @@ def optimize_acqf(
277
289
batch_limit : int = options .get (
278
290
"batch_limit" , num_restarts if not nonlinear_inequality_constraints else 1
279
291
)
292
+ has_parameter_constraints = (
293
+ inequality_constraints is not None
294
+ or equality_constraints is not None
295
+ or nonlinear_inequality_constraints is not None
296
+ )
280
297
281
298
def _optimize_batch_candidates (
282
299
timeout_sec : Optional [float ],
@@ -288,24 +305,36 @@ def _optimize_batch_candidates(
288
305
if timeout_sec is not None :
289
306
timeout_sec = (timeout_sec - start_time ) / len (batched_ics )
290
307
291
- scipy_kws = {
308
+ gen_kwargs = {
292
309
"acquisition_function" : acq_function ,
293
310
"lower_bounds" : None if bounds [0 ].isinf ().all () else bounds [0 ],
294
311
"upper_bounds" : None if bounds [1 ].isinf ().all () else bounds [1 ],
295
312
"options" : {k : v for k , v in options .items () if k not in INIT_OPTION_KEYS },
296
- "inequality_constraints" : inequality_constraints ,
297
- "equality_constraints" : equality_constraints ,
298
- "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
299
313
"fixed_features" : fixed_features ,
300
314
"timeout_sec" : timeout_sec ,
301
315
}
302
316
317
+ if has_parameter_constraints :
318
+ # only add parameter constraints to gen_kwargs if they are specified
319
+ # to avoid unnecessary warnings in _filter_kwargs
320
+ gen_kwargs .update (
321
+ {
322
+ "inequality_constraints" : inequality_constraints ,
323
+ "equality_constraints" : equality_constraints ,
324
+ # the line is too long
325
+ "nonlinear_inequality_constraints" : (
326
+ nonlinear_inequality_constraints
327
+ ),
328
+ }
329
+ )
330
+ filtered_gen_kwargs = _filter_kwargs (gen_candidates , ** gen_kwargs )
331
+
303
332
for i , batched_ics_ in enumerate (batched_ics ):
304
333
# optimize using random restart optimization
305
334
with warnings .catch_warnings (record = True ) as ws :
306
335
warnings .simplefilter ("always" , category = OptimizationWarning )
307
- batch_candidates_curr , batch_acq_values_curr = gen_candidates_scipy (
308
- initial_conditions = batched_ics_ , ** scipy_kws
336
+ batch_candidates_curr , batch_acq_values_curr = gen_candidates (
337
+ initial_conditions = batched_ics_ , ** filtered_gen_kwargs
309
338
)
310
339
opt_warnings += ws
311
340
batch_candidates_list .append (batch_candidates_curr )
0 commit comments