Skip to content

Commit f692120

Browse files
sdaulton authored and facebook-github-bot committed
pass gen_candidates callable in optimize_acqf (#1655)
Summary: Pull Request resolved: #1655 see title. This will support using stochastic optimization Differential Revision: D41629164 fbshipit-source-id: 0f31bdc3392f47546da31183fa2166bf18ec174b
1 parent 076af96 commit f692120

File tree

3 files changed

+336
-231
lines changed

3 files changed

+336
-231
lines changed

botorch/generation/gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
logger = _get_logger()
3939

40+
TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
41+
4042

4143
def gen_candidates_scipy(
4244
initial_conditions: Tensor,
@@ -151,7 +153,6 @@ def gen_candidates_scipy(
151153
clamped_candidates
152154
)
153155
return clamped_candidates, batch_acquisition
154-
155156
clamped_candidates = columnwise_clamp(
156157
X=initial_conditions, lower=lower_bounds, upper=upper_bounds
157158
)
@@ -359,7 +360,6 @@ def gen_candidates_torch(
359360
clamped_candidates
360361
)
361362
return clamped_candidates, batch_acquisition
362-
363363
_clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
364364
clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
365365
_optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))

botorch/optim/optimize.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2424
from botorch.exceptions import InputDataError, UnsupportedError
2525
from botorch.exceptions.warnings import OptimizationWarning
26-
from botorch.generation.gen import gen_candidates_scipy
26+
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
2727
from botorch.logging import logger
2828
from botorch.optim.initializers import (
2929
gen_batch_initial_conditions,
3030
gen_one_shot_kg_initial_conditions,
3131
)
3232
from botorch.optim.stopping import ExpMAStoppingCriterion
33+
from botorch.optim.utils import _filter_kwargs
3334
from torch import Tensor
3435

3536
INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
6465
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
6566
batch_initial_conditions: Optional[Tensor] = None,
6667
return_best_only: bool = True,
68+
gen_candidates: TGenCandidates = gen_candidates_scipy,
6769
sequential: bool = False,
6870
**kwargs: Any,
6971
) -> Tuple[Tensor, Tensor]:
@@ -103,6 +105,12 @@ def optimize_acqf(
103105
this if you do not want to use default initialization strategy.
104106
return_best_only: If False, outputs the solutions corresponding to all
105107
random restart initializations of the optimization.
108+
gen_candidates: A callable for generating candidates (and their associated
109+
acquisition values) given a tensor of initial conditions and an
110+
acquisition function. Other common inputs include lower and upper bounds
111+
and a dictionary of options, but refer to the documentation of specific
112+
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
113+
for method-specific inputs. Default: `gen_candidates_scipy`
106114
sequential: If False, uses joint optimization, otherwise uses sequential
107115
optimization.
108116
kwargs: Additonal keyword arguments.
@@ -214,6 +222,7 @@ def optimize_acqf(
214222
sequential=False,
215223
ic_generator=ic_gen,
216224
timeout_sec=timeout_sec,
225+
gen_candidates=gen_candidates,
217226
)
218227

219228
candidate_list.append(candidate)
@@ -262,6 +271,11 @@ def optimize_acqf(
262271
batch_limit: int = options.get(
263272
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
264273
)
274+
has_parameter_constraints = (
275+
inequality_constraints is not None
276+
or equality_constraints is not None
277+
or nonlinear_inequality_constraints is not None
278+
)
265279

266280
def _optimize_batch_candidates(
267281
timeout_sec: Optional[float],
@@ -273,24 +287,33 @@ def _optimize_batch_candidates(
273287
if timeout_sec is not None:
274288
timeout_sec = (timeout_sec - start_time) / len(batched_ics)
275289

276-
scipy_kws = {
290+
gen_kwargs = {
277291
"acquisition_function": acq_function,
278292
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
279293
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
280294
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
281-
"inequality_constraints": inequality_constraints,
282-
"equality_constraints": equality_constraints,
283-
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
284295
"fixed_features": fixed_features,
285296
"timeout_sec": timeout_sec,
286297
}
287298

299+
if has_parameter_constraints:
300+
# only add parameter constraints to gen_kwargs if they are specified
301+
# to avoid unnecessary warnings in _filter_kwargs
302+
gen_kwargs.update(
303+
{
304+
"inequality_constraints": inequality_constraints,
305+
"equality_constraints": equality_constraints,
306+
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
307+
}
308+
)
309+
filtered_gen_kwargs = _filter_kwargs(gen_candidates, **gen_kwargs)
310+
288311
for i, batched_ics_ in enumerate(batched_ics):
289312
# optimize using random restart optimization
290313
with warnings.catch_warnings(record=True) as ws:
291314
warnings.simplefilter("always", category=OptimizationWarning)
292-
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
293-
initial_conditions=batched_ics_, **scipy_kws
315+
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
316+
initial_conditions=batched_ics_, **filtered_gen_kwargs
294317
)
295318
opt_warnings += ws
296319
batch_candidates_list.append(batch_candidates_curr)

0 commit comments

Comments (0)