
utils.sigmoid with log and fat options #1938

Closed · wants to merge 2 commits
12 changes: 12 additions & 0 deletions botorch/acquisition/__init__.py
@@ -16,6 +16,8 @@
AnalyticAcquisitionFunction,
ConstrainedExpectedImprovement,
ExpectedImprovement,
LogExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
PosteriorMean,
ProbabilityOfImprovement,
@@ -32,6 +34,10 @@
qKnowledgeGradient,
qMultiFidelityKnowledgeGradient,
)
from botorch.acquisition.logei import (
LogImprovementMCAcquisitionFunction,
qLogExpectedImprovement,
)
from botorch.acquisition.max_value_entropy_search import (
MaxValueBase,
qLowerBoundMaxValueEntropy,
@@ -46,6 +52,7 @@
qProbabilityOfImprovement,
qSimpleRegret,
qUpperConfidenceBound,
SampleReducingMCAcquisitionFunction,
)
from botorch.acquisition.multi_step_lookahead import qMultiStepLookahead
from botorch.acquisition.objective import (
@@ -71,6 +78,8 @@
"AnalyticExpectedUtilityOfBestOption",
"ConstrainedExpectedImprovement",
"ExpectedImprovement",
"LogExpectedImprovement",
"LogNoisyExpectedImprovement",
"FixedFeatureAcquisitionFunction",
"GenericCostAwareUtility",
"InverseCostWeightedUtility",
@@ -85,6 +94,8 @@
"UpperConfidenceBound",
"qAnalyticProbabilityOfImprovement",
"qExpectedImprovement",
"LogImprovementMCAcquisitionFunction",
"qLogExpectedImprovement",
"qKnowledgeGradient",
"MaxValueBase",
"qMultiFidelityKnowledgeGradient",
@@ -104,6 +115,7 @@
"LearnedObjective",
"LinearMCObjective",
"MCAcquisitionFunction",
"SampleReducingMCAcquisitionFunction",
"MCAcquisitionObjective",
"ScalarizedPosteriorTransform",
"get_acquisition_function",
3 changes: 2 additions & 1 deletion botorch/acquisition/input_constructors.py
@@ -47,6 +47,7 @@
qKnowledgeGradient,
qMultiFidelityKnowledgeGradient,
)
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.acquisition.max_value_entropy_search import (
qMaxValueEntropy,
qMultiFidelityMaxValueEntropy,
@@ -449,7 +450,7 @@ def construct_inputs_qSimpleRegret(
)
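# NOTE: qLogExpectedImprovement reuses qEI's input constructor below, since the
# two acquisition functions accept the same constructor arguments.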


@acqf_input_constructor(qExpectedImprovement)
@acqf_input_constructor(qExpectedImprovement, qLogExpectedImprovement)
def construct_inputs_qEI(
model: Model,
training_data: MaybeDict[SupervisedDataset],
261 changes: 261 additions & 0 deletions botorch/acquisition/logei.py
@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Batch implementations of the LogEI family of improvement-based acquisition functions.
"""


from __future__ import annotations

from functools import partial

from typing import Callable, List, Optional, TypeVar, Union

import torch
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.acquisition.objective import (
ConstrainedMCObjective,
MCAcquisitionObjective,
PosteriorTransform,
)
from botorch.exceptions.errors import BotorchError
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.safe_math import (
fatmax,
log_fatplus,
log_softplus,
logmeanexp,
smooth_amax,
)
from torch import Tensor

"""
NOTE: On the default temperature parameters:

tau_relu: It is generally important to set `tau_relu` to be very small, in particular,
smaller than the expected improvement value. Otherwise, the optimization can stagnate.
With the default `tau_relu=1e-6`, stagnation due to the smooth ReLU approximation is
exceedingly unlikely for practical applications of BO.
IDEA: We could consider shrinking `tau_relu` as the optimization progresses.

tau_max: This is only relevant for the batch (`q > 1`) case, and `tau_max=1e-2` is
sufficient to get a good approximation to the maximum improvement in the batch of
candidates. If `fat=False`, the smooth approximation to the maximum can saturate
numerically. It is therefore recommended to use `fat=True` when optimizing batches
of `q > 1` points.
"""
TAU_RELU = 1e-6
TAU_MAX = 1e-2
FloatOrTensor = TypeVar("FloatOrTensor", float, Tensor)


class LogImprovementMCAcquisitionFunction(SampleReducingMCAcquisitionFunction):
r"""
Abstract base class for Monte-Carlo-based batch LogEI acquisition functions.

:meta private:
"""

_log: bool = True

def __init__(
self,
model: Model,
sampler: Optional[MCSampler] = None,
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
X_pending: Optional[Tensor] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
fat: bool = True,
tau_max: float = TAU_MAX,
) -> None:
r"""
Args:
model: A fitted model.
sampler: The sampler used to draw base samples. If not given,
a sampler is generated using `get_sampler`.
NOTE: For posteriors that do not support base samples,
a sampler compatible with the intended use case must be provided.
See `ForkedRNGSampler` and `StochasticSampler` as examples.
objective: The MCAcquisitionObjective under which the samples are
evaluated. Defaults to `IdentityMCObjective()`.
posterior_transform: A PosteriorTransform (optional).
X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
that have been submitted for function evaluation but have not yet
been evaluated.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are satisfied if `constraint(samples) < 0`; see the example after this list.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. See the docs of
`compute_smoothed_feasibility_indicator` for more details on this parameter.
fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
approximation to the ReLU.
tau_max: Temperature parameter controlling the sharpness of the
approximation to the `max` operator over the `q` candidate points.
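Example of a `constraints` entry (editor's illustration, assuming an
`m = 2`-output model whose second output is feasible iff it is negative):
constraints=[lambda samples: samples[..., 1]]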
"""
if isinstance(objective, ConstrainedMCObjective):
raise BotorchError(
"Log-Improvement should not be used with `ConstrainedMCObjective`."
"Please pass the `constraints` directly to the constructor of the "
"acquisition function."
)
q_reduction = partial(fatmax if fat else smooth_amax, tau=tau_max)
super().__init__(
model=model,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
sample_reduction=logmeanexp,
q_reduction=q_reduction,
constraints=constraints,
eta=eta,
fat=fat,
)
self.tau_max = tau_max
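# Editor's note on the wiring above (a sketch, following the class docstring):
# subclasses emit per-sample *log* utilities, which the parent class reduces by
# a smooth maximum over the q candidates (fatmax / smooth_amax, temperature
# tau_max) and then by logmeanexp over the MC samples, yielding a numerically
# stable approximation to log E[max_q improvement], i.e. the log of the batch EI.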


class qLogExpectedImprovement(LogImprovementMCAcquisitionFunction):
r"""MC-based batch Log Expected Improvement.

This computes qLogEI by
(1) sampling the joint posterior over q points,
(2) evaluating the smoothed log improvement over the current best for each sample,
(3) smoothly maximizing over q, and
(4) averaging over the samples in log space.

`qLogEI(X) ~ log(qEI(X)) = log(E(max(max Y - best_f, 0)))`,

where `Y ~ f(X)`, and `X = (x_1,...,x_q)`.

Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> best_f = train_Y.max()
>>> sampler = SobolQMCNormalSampler(1024)
>>> qLogEI = qLogExpectedImprovement(model, best_f, sampler)
>>> qlog_ei = qLogEI(test_X)
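>>> qei = qlog_ei.exp()  # editor's addition: recover the qEI scale (approximately)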
"""

def __init__(
self,
model: Model,
best_f: Union[float, Tensor],
sampler: Optional[MCSampler] = None,
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
X_pending: Optional[Tensor] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
fat: bool = True,
tau_max: float = TAU_MAX,
tau_relu: float = TAU_RELU,
) -> None:
r"""q-Log Expected Improvement.

Args:
model: A fitted model.
best_f: The best objective value observed so far (assumed noiseless). Can be
a `batch_shape`-shaped tensor, which in the case of a batched model
specifies potentially different values for each element of the batch.
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
for more details.
objective: The MCAcquisitionObjective under which the samples are evaluated.
Defaults to `IdentityMCObjective()`.
posterior_transform: A PosteriorTransform (optional).
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation but have not yet been evaluated.
Concatenated into `X` upon forward call. Copied and set to have no
gradient.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are satisfied if `constraint(samples) < 0`.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. See the docs of
`compute_smoothed_feasibility_indicator` for details.
fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
approximation to the ReLU.
tau_max: Temperature parameter controlling the sharpness of the smooth
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
"""
super().__init__(
model=model,
sampler=sampler,
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
constraints=constraints,
eta=eta,
tau_max=check_tau(tau_max, name="tau_max"),
fat=fat,
)
self.register_buffer("best_f", torch.as_tensor(best_f))
self.tau_relu = check_tau(tau_relu, name="tau_relu")

def _sample_forward(self, obj: Tensor) -> Tensor:
r"""Evaluate qLogExpectedImprovement on the candidate set `X`.

Args:
obj: `mc_shape x batch_shape x q`-dim Tensor of MC objective values.

Returns:
A `mc_shape x batch_shape x q`-dim Tensor of log improvement utilities.
"""
li = _log_improvement(
Y=obj,
best_f=self.best_f,
tau=self.tau_relu,
fat=self._fat,
)
return li


"""
###################################### utils ##########################################
"""


def _log_improvement(
Y: Tensor,
best_f: Tensor,
tau: Union[float, Tensor],
fat: bool,
) -> Tensor:
"""Computes the logarithm of the softplus-smoothed improvement, i.e.
`log_softplus(Y - best_f, beta=(1 / tau))`.
Note that softplus is an approximation to the regular ReLU objective whose maximum
pointwise approximation error is linear with respect to tau as tau goes to zero.

Args:
Y: `mc_samples x batch_shape x q`-dim Tensor of output samples.
best_f: Best previously observed objective value(s), broadcastable with `Y`.
tau: Temperature parameter for the smooth approximation of the ReLU;
as `tau -> 0`, the maximum pointwise approximation error is linear w.r.t. `tau`.
fat: Toggles the logarithmic / linear asymptotic behavior of the
smooth approximation to ReLU.

Returns:
A `mc_samples x batch_shape x q`-dim Tensor of improvement values.
"""
log_soft_clamp = log_fatplus if fat else log_softplus
Z = Y - best_f.to(Y)
return log_soft_clamp(Z, tau=tau) # ~ log((Y - best_f).clamp(0))
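# Editor's note on `fat` above: for large negative `Z = Y - best_f`,
# `log_softplus(Z, tau)` behaves like `Z / tau`, so its magnitude explodes as
# `tau -> 0`, whereas `log_fatplus` has polynomial (fat) tails and decays only
# logarithmically in `|Z|`, keeping log-acquisition values and their gradients
# on a moderate scale far from the incumbent.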


def check_tau(tau: FloatOrTensor, name: str) -> FloatOrTensor:
"""Checks the validity of the tau arguments of the functions below, and returns
`tau` if it is valid."""
if isinstance(tau, Tensor) and tau.numel() != 1:
raise ValueError(name + f" is not a scalar: {tau.numel() = }.")
if not (tau > 0):
raise ValueError(name + f" is non-positive: {tau = }.")
return tau
17 changes: 14 additions & 3 deletions botorch/acquisition/monte_carlo.py
@@ -170,6 +170,8 @@ class SampleReducingMCAcquisitionFunction(MCAcquisitionFunction):
forward pass. These problems are circumvented by the design of this class.
"""

_log: bool = False # whether the acquisition utilities are in log-space

def __init__(
self,
model: Model,
@@ -181,6 +183,7 @@ def __init__(
q_reduction: SampleReductionProtocol = torch.amax,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
fat: bool = False,
):
r"""Constructor of SampleReducingMCAcquisitionFunction.

@@ -216,6 +219,8 @@ def __init__(
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details on this
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
fat: Whether to apply a fat-tailed smooth approximation to the feasibility
indicator rather than the canonical sigmoid approximation.
"""
if constraints is not None and isinstance(objective, ConstrainedMCObjective):
raise ValueError(
@@ -236,6 +241,7 @@ def __init__(
self._q_reduction = partial(q_reduction, dim=-1)
self._constraints = constraints
self._eta = eta
self._fat = fat

@concatenate_pending_points
@t_batch_mode_transform()
@@ -300,14 +306,19 @@ def _apply_constraints(self, acqval: Tensor, samples: Tensor) -> Tensor:
multiplied by a smoothed constraint indicator per sample.
"""
if self._constraints is not None:
if (acqval < 0).any():
if not self._log and (acqval < 0).any():
raise ValueError(
"Constraint-weighting requires unconstrained "
"acquisition values to be non-negative."
)
acqval = acqval * compute_smoothed_feasibility_indicator(
constraints=self._constraints, samples=samples, eta=self._eta
ind = compute_smoothed_feasibility_indicator(
constraints=self._constraints,
samples=samples,
eta=self._eta,
log=self._log,
fat=self._fat,
)
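# In log space the multiplicative feasibility weighting `acq * ind` becomes
# additive: with `log=True` above, `ind` is returned as a log-indicator, and
# log(acq * ind) = log(acq) + log(ind).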
acqval = acqval.add(ind) if self._log else acqval.mul(ind)
return acqval

