Skip to content

Commit 89f923d

Browse files
sdaulton authored and facebook-github-bot committed
pass gen_candidates callable in optimize_acqf (#1655)
Summary: Pull Request resolved: #1655 see title. This will support using stochastic optimization Reviewed By: esantorella Differential Revision: D41629164 fbshipit-source-id: c2edab3d6a000ef41844c86558de3e5ebeb3a2ce
1 parent 58090d3 commit 89f923d

File tree

4 files changed

+358
-231
lines changed

4 files changed

+358
-231
lines changed

botorch/generation/gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
logger = _get_logger()
3939

40+
TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
41+
4042

4143
def gen_candidates_scipy(
4244
initial_conditions: Tensor,
@@ -152,7 +154,6 @@ def gen_candidates_scipy(
152154
clamped_candidates
153155
)
154156
return clamped_candidates, batch_acquisition
155-
156157
clamped_candidates = columnwise_clamp(
157158
X=initial_conditions, lower=lower_bounds, upper=upper_bounds
158159
)
@@ -360,7 +361,6 @@ def gen_candidates_torch(
360361
clamped_candidates
361362
)
362363
return clamped_candidates, batch_acquisition
363-
364364
_clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
365365
clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
366366
_optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))

botorch/optim/optimize.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2424
from botorch.exceptions import InputDataError, UnsupportedError
2525
from botorch.exceptions.warnings import OptimizationWarning
26-
from botorch.generation.gen import gen_candidates_scipy
26+
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
2727
from botorch.logging import logger
2828
from botorch.optim.initializers import (
2929
gen_batch_initial_conditions,
3030
gen_one_shot_kg_initial_conditions,
3131
)
3232
from botorch.optim.stopping import ExpMAStoppingCriterion
33+
from botorch.optim.utils import _filter_kwargs
3334
from torch import Tensor
3435

3536
INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
6465
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
6566
batch_initial_conditions: Optional[Tensor] = None,
6667
return_best_only: bool = True,
68+
gen_candidates: Optional[TGenCandidates] = None,
6769
sequential: bool = False,
6870
**kwargs: Any,
6971
) -> Tuple[Tensor, Tensor]:
@@ -103,6 +105,12 @@ def optimize_acqf(
103105
this if you do not want to use default initialization strategy.
104106
return_best_only: If False, outputs the solutions corresponding to all
105107
random restart initializations of the optimization.
108+
gen_candidates: A callable for generating candidates (and their associated
109+
acquisition values) given a tensor of initial conditions and an
110+
acquisition function. Other common inputs include lower and upper bounds
111+
and a dictionary of options, but refer to the documentation of specific
112+
generation functions (e.g. gen_candidates_scipy and gen_candidates_torch)
113+
for method-specific inputs. Default: `gen_candidates_scipy`
106114
sequential: If False, uses joint optimization, otherwise uses sequential
107115
optimization.
108116
kwargs: Additional keyword arguments.
@@ -134,6 +142,9 @@ def optimize_acqf(
134142
"""
135143
start_time: float = time.monotonic()
136144
timeout_sec = kwargs.pop("timeout_sec", None)
145+
# using a default of None simplifies unit testing
146+
if gen_candidates is None:
147+
gen_candidates = gen_candidates_scipy
137148

138149
if inequality_constraints is None:
139150
if not (bounds.ndim == 2 and bounds.shape[0] == 2):
@@ -229,6 +240,7 @@ def optimize_acqf(
229240
sequential=False,
230241
ic_generator=ic_gen,
231242
timeout_sec=timeout_sec,
243+
gen_candidates=gen_candidates,
232244
)
233245

234246
candidate_list.append(candidate)
@@ -277,6 +289,11 @@ def optimize_acqf(
277289
batch_limit: int = options.get(
278290
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
279291
)
292+
has_parameter_constraints = (
293+
inequality_constraints is not None
294+
or equality_constraints is not None
295+
or nonlinear_inequality_constraints is not None
296+
)
280297

281298
def _optimize_batch_candidates(
282299
timeout_sec: Optional[float],
@@ -288,24 +305,36 @@ def _optimize_batch_candidates(
288305
if timeout_sec is not None:
289306
timeout_sec = (timeout_sec - start_time) / len(batched_ics)
290307

291-
scipy_kws = {
308+
gen_kwargs = {
292309
"acquisition_function": acq_function,
293310
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
294311
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
295312
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
296-
"inequality_constraints": inequality_constraints,
297-
"equality_constraints": equality_constraints,
298-
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
299313
"fixed_features": fixed_features,
300314
"timeout_sec": timeout_sec,
301315
}
302316

317+
if has_parameter_constraints:
318+
# only add parameter constraints to gen_kwargs if they are specified
319+
# to avoid unnecessary warnings in _filter_kwargs
320+
gen_kwargs.update(
321+
{
322+
"inequality_constraints": inequality_constraints,
323+
"equality_constraints": equality_constraints,
324+
# the line is too long
325+
"nonlinear_inequality_constraints": (
326+
nonlinear_inequality_constraints
327+
),
328+
}
329+
)
330+
filtered_gen_kwargs = _filter_kwargs(gen_candidates, **gen_kwargs)
331+
303332
for i, batched_ics_ in enumerate(batched_ics):
304333
# optimize using random restart optimization
305334
with warnings.catch_warnings(record=True) as ws:
306335
warnings.simplefilter("always", category=OptimizationWarning)
307-
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
308-
initial_conditions=batched_ics_, **scipy_kws
336+
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
337+
initial_conditions=batched_ics_, **filtered_gen_kwargs
309338
)
310339
opt_warnings += ws
311340
batch_candidates_list.append(batch_candidates_curr)

test/acquisition/test_knowledge_gradient.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
)
2626
from botorch.acquisition.utils import project_to_sample_points
2727
from botorch.exceptions.errors import UnsupportedError
28+
from botorch.generation.gen import gen_candidates_scipy
2829
from botorch.models import SingleTaskGP
2930
from botorch.optim.optimize import optimize_acqf
31+
from botorch.optim.utils import _filter_kwargs
3032
from botorch.posteriors.gpytorch import GPyTorchPosterior
3133
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
3234
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -593,7 +595,13 @@ def test_optimize_w_posterior_transform(self):
593595
torch.zeros(2, n_f + 1, 2, **tkwargs),
594596
torch.zeros(2, **tkwargs),
595597
),
598+
), mock.patch(
599+
f"{optimize_acqf.__module__}._filter_kwargs",
600+
wraps=lambda f, **kwargs: _filter_kwargs(
601+
function=gen_candidates_scipy, **kwargs
602+
),
596603
):
604+
597605
candidate, value = optimize_acqf(
598606
acq_function=kg,
599607
bounds=bounds,

0 commit comments

Comments
 (0)