Skip to content

Commit 3b9906f

Browse files
sdaulton authored and facebook-github-bot committed
pass gen_candidates callable in optimize_acqf
Summary: see title. This will support using stochastic optimization. Differential Revision: D41629164 fbshipit-source-id: 95cd4cf5612cfce604cede1382db07f8d187fb59
1 parent 52676e9 commit 3b9906f

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

botorch/generation/gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
logger = _get_logger()
3939

40+
TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
41+
4042

4143
def gen_candidates_scipy(
4244
initial_conditions: Tensor,
@@ -49,6 +51,7 @@ def gen_candidates_scipy(
4951
options: Optional[Dict[str, Any]] = None,
5052
fixed_features: Optional[Dict[int, Optional[float]]] = None,
5153
timeout_sec: Optional[float] = None,
54+
**kwargs,
5255
) -> Tuple[Tensor, Tensor]:
5356
r"""Generate a set of candidates using `scipy.optimize.minimize`.
5457
@@ -281,6 +284,7 @@ def gen_candidates_torch(
281284
callback: Optional[Callable[[int, Tensor, Tensor], NoReturn]] = None,
282285
fixed_features: Optional[Dict[int, Optional[float]]] = None,
283286
timeout_sec: Optional[float] = None,
287+
**kwargs,
284288
) -> Tuple[Tensor, Tensor]:
285289
r"""Generate a set of candidates using a `torch.optim` optimizer.
286290

botorch/optim/optimize.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2424
from botorch.exceptions import InputDataError, UnsupportedError
2525
from botorch.exceptions.warnings import OptimizationWarning
26-
from botorch.generation.gen import gen_candidates_scipy
26+
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
2727
from botorch.logging import logger
2828
from botorch.optim.initializers import (
2929
gen_batch_initial_conditions,
@@ -64,6 +64,7 @@ def optimize_acqf(
6464
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
6565
batch_initial_conditions: Optional[Tensor] = None,
6666
return_best_only: bool = True,
67+
gen_candidates: TGenCandidates = gen_candidates_scipy,
6768
sequential: bool = False,
6869
**kwargs: Any,
6970
) -> Tuple[Tensor, Tensor]:
@@ -103,6 +104,8 @@ def optimize_acqf(
103104
this if you do not want to use default initialization strategy.
104105
return_best_only: If False, outputs the solutions corresponding to all
105106
random restart initializations of the optimization.
107+
gen_candidates: A callable for generating candidates given initial
108+
conditions. Default: `gen_candidates_scipy`
106109
sequential: If False, uses joint optimization, otherwise uses sequential
107110
optimization.
108111
kwargs: Additonal keyword arguments.
@@ -273,7 +276,7 @@ def _optimize_batch_candidates(
273276
if timeout_sec is not None:
274277
timeout_sec = (timeout_sec - start_time) / len(batched_ics)
275278

276-
scipy_kws = {
279+
gen_kws = {
277280
"acquisition_function": acq_function,
278281
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
279282
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
@@ -289,8 +292,8 @@ def _optimize_batch_candidates(
289292
# optimize using random restart optimization
290293
with warnings.catch_warnings(record=True) as ws:
291294
warnings.simplefilter("always", category=OptimizationWarning)
292-
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
293-
initial_conditions=batched_ics_, **scipy_kws
295+
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
296+
initial_conditions=batched_ics_, **gen_kws
294297
)
295298
opt_warnings += ws
296299
batch_candidates_list.append(batch_candidates_curr)

0 commit comments

Comments
 (0)