
Commit bdb6b7d

SebastianAment authored and facebook-github-bot committed
compute_best_feasible_f utility (#1931)
Summary:
Pull Request resolved: #1931

This commit separates out the `compute_best_feasible_f` function from `qNEI` as a utility function, in order to use it in the `input_constructors` and `get_acquisition_function` (follow-up).

Reviewed By: Balandat

Differential Revision: D47365085

fbshipit-source-id: 2d8203544bd646638d9ea6c20789d0700f600c0b
1 parent a1b38fc commit bdb6b7d
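The summary mentions a follow-up in the `input_constructors`. As a hypothetical sketch of that usage (not part of this commit; the helper name `best_f_from_observations` is invented for illustration), the new utility lets callers derive a `best_f` value without instantiating `qNEI`:

```python
# Hypothetical follow-up usage (not in this commit): derive `best_f` for
# acquisition functions such as qEI/qPI directly from observed objective values.
import torch
from botorch.acquisition.utils import compute_best_feasible_objective


def best_f_from_observations(obj_vals: torch.Tensor) -> torch.Tensor:
    # obj_vals: a `q`-dim tensor of objective values at observed points.
    return compute_best_feasible_objective(
        samples=obj_vals.unsqueeze(-1),  # `q x m` with m=1
        obj=obj_vals,
        constraints=None,  # unconstrained: reduces to a plain max
    )


print(best_f_from_observations(torch.tensor([0.1, 0.9, 0.4])))  # tensor([0.9000])
```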

File tree

5 files changed: +275 −42 lines changed


botorch/acquisition/monte_carlo.py

Lines changed: 15 additions & 35 deletions
@@ -27,7 +27,6 @@
 from typing import Any, Callable, List, Optional, Protocol, Tuple, Union

 import torch
-from botorch import acquisition
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
 from botorch.acquisition.cached_cholesky import CachedCholeskyMCAcquisitionFunction
 from botorch.acquisition.objective import (
@@ -36,7 +35,10 @@
     MCAcquisitionObjective,
     PosteriorTransform,
 )
-from botorch.acquisition.utils import prune_inferior_points
+from botorch.acquisition.utils import (
+    compute_best_feasible_objective,
+    prune_inferior_points,
+)
 from botorch.exceptions.errors import UnsupportedError
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
@@ -591,46 +593,24 @@ def _get_samples_and_objectives(self, X: Tensor) -> Tuple[Tensor, Tensor]:
         return samples, obj

     def _compute_best_feasible_objective(self, samples: Tensor, obj: Tensor) -> Tensor:
-        """
+        r"""Computes best feasible objective value from samples.
+
         Args:
             samples: `sample_shape x batch_shape x q x m`-dim posterior samples.
             obj: A `sample_shape x batch_shape x q`-dim Tensor of MC objective values.

         Returns:
             A `sample_shape x batch_shape x 1`-dim Tensor of best feasible objectives.
         """
-        if self._constraints is not None:
-            # is_feasible is sample_shape x batch_shape x q
-            is_feasible = compute_smoothed_constraint_indicator(
-                constraints=self._constraints, samples=samples, eta=self._eta
-            )
-            is_feasible = is_feasible > 0.5  # due to smooth approximation
-            if is_feasible.any():
-                obj = torch.where(is_feasible, obj, -torch.inf)
-            else:  # if there are no feasible observations, estimate a lower
-                # bound on the objective by sampling convex combinations of X_baseline.
-                convex_weights = torch.rand(
-                    32,
-                    self.X_baseline.shape[-2],
-                    dtype=self.X_baseline.dtype,
-                    device=self.X_baseline.device,
-                )
-                weights_sum = convex_weights.sum(dim=0, keepdim=True)
-                convex_weights = convex_weights / weights_sum
-                # infeasible cost M is such that -M < min_x f(x), thus
-                # 0 < min_x f(x) - (-M), so we should take -M as a lower
-                # bound on the best feasible objective
-                return -acquisition.utils.get_infeasible_cost(
-                    X=convex_weights @ self.X_baseline,
-                    model=self.model,
-                    objective=self.objective,
-                    posterior_transform=self.posterior_transform,
-                ).expand(*obj.shape[:-1], 1)
-
-        # we don't need to differentiate through X_baseline for now, so taking
-        # the regular max over the n points to get best_f is fine
-        with torch.no_grad():
-            return obj.amax(dim=-1, keepdim=True)
+        return compute_best_feasible_objective(
+            samples=samples,
+            obj=obj,
+            constraints=self._constraints,
+            model=self.model,
+            objective=self.objective,
+            posterior_transform=self.posterior_transform,
+            X_baseline=self.X_baseline,
+        )


 class qProbabilityOfImprovement(SampleReducingMCAcquisitionFunction):
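For context, a minimal sketch showing that `qNoisyExpectedImprovement`'s public behavior is unchanged by this refactor; only the internal best-f computation now routes through the shared utility. The `SingleTaskGP` and random training data here are illustrative choices, not part of the diff:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement

# Fit-free toy model: 8 observed points in 2 dimensions.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.norm(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

acqf = qNoisyExpectedImprovement(model=model, X_baseline=train_X)
values = acqf(torch.rand(4, 3, 2, dtype=torch.double))  # 4 batches of q=3
print(values.shape)  # torch.Size([4])
```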

botorch/acquisition/utils.py

Lines changed: 108 additions & 0 deletions
@@ -32,6 +32,7 @@
     FastNondominatedPartitioning,
     NondominatedPartitioning,
 )
+from botorch.utils.objective import compute_feasibility_indicator
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_fully_bayesian
 from torch import Tensor
@@ -213,6 +214,113 @@ def get_acquisition_function(
     )


+def compute_best_feasible_objective(
+    samples: Tensor,
+    obj: Tensor,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]],
+    model: Optional[Model] = None,
+    objective: Optional[MCAcquisitionObjective] = None,
+    posterior_transform: Optional[PosteriorTransform] = None,
+    X_baseline: Optional[Tensor] = None,
+    infeasible_obj: Optional[Tensor] = None,
+) -> Tensor:
+    """Computes the largest `obj` value that is feasible under the `constraints`.
+    If `constraints` is None, returns the best unconstrained objective value.
+
+    When no feasible observations exist and `infeasible_obj` is not `None`, returns
+    `infeasible_obj` (potentially reshaped). When no feasible observations exist and
+    `infeasible_obj` is `None`, uses `model`, `objective`, `posterior_transform`, and
+    `X_baseline` to infer and return an `infeasible_obj` `M` s.t. `M < min_x f(x)`.
+
+    Args:
+        samples: `(sample_shape) x batch_shape x q x m`-dim posterior samples.
+        obj: A `(sample_shape) x batch_shape x q`-dim Tensor of MC objective values.
+        constraints: A list of constraint callables which map posterior samples to
+            a scalar. The associated constraint is considered satisfied if this
+            scalar is less than zero.
+        model: A Model, only required when there are no feasible observations.
+        objective: An MCAcquisitionObjective, only optionally used when there are no
+            feasible observations.
+        posterior_transform: A PosteriorTransform, only optionally used when there are
+            no feasible observations.
+        X_baseline: A `batch_shape x d`-dim Tensor of baseline points, only required
+            when there are no feasible observations.
+        infeasible_obj: A Tensor to be returned when no feasible points exist.
+
+    Returns:
+        A `(sample_shape) x batch_shape x 1`-dim Tensor of best feasible objectives.
+    """
+    if constraints is None:  # unconstrained case
+        # we don't need to differentiate through X_baseline for now, so taking
+        # the regular max over the n points to get best_f is fine
+        with torch.no_grad():
+            return obj.amax(dim=-1, keepdim=True)
+
+    is_feasible = compute_feasibility_indicator(
+        constraints=constraints, samples=samples
+    )  # sample_shape x batch_shape x q
+    if is_feasible.any():
+        obj = torch.where(is_feasible, obj, -torch.inf)
+        with torch.no_grad():
+            return obj.amax(dim=-1, keepdim=True)
+
+    elif infeasible_obj is not None:
+        return infeasible_obj.expand(*obj.shape[:-1], 1)
+
+    else:
+        if model is None:
+            raise ValueError(
+                "Must specify `model` when no feasible observation exists."
+            )
+        if X_baseline is None:
+            raise ValueError(
+                "Must specify `X_baseline` when no feasible observation exists."
+            )
+        return _estimate_objective_lower_bound(
+            model=model,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X=X_baseline,
+        ).expand(*obj.shape[:-1], 1)


+def _estimate_objective_lower_bound(
+    model: Model,
+    objective: Optional[MCAcquisitionObjective],
+    posterior_transform: Optional[PosteriorTransform],
+    X: Tensor,
+) -> Tensor:
+    """Estimates a lower bound on the objective values by evaluating the model at
+    convex combinations of `X`, returning the 6-sigma lower bound of the computed
+    statistics.
+
+    Args:
+        model: A fitted model.
+        objective: An MCAcquisitionObjective with `m` outputs.
+        posterior_transform: A PosteriorTransform.
+        X: A `n x d`-dim Tensor of design points from which to draw convex
+            combinations.
+
+    Returns:
+        A `m`-dimensional Tensor of lower bounds of the objectives.
+    """
+    convex_weights = torch.rand(
+        32,
+        X.shape[-2],
+        dtype=X.dtype,
+        device=X.device,
+    )
+    weights_sum = convex_weights.sum(dim=0, keepdim=True)
+    convex_weights = convex_weights / weights_sum
+    # infeasible cost M is such that -M < min_x f(x), thus
+    # 0 < min_x f(x) - (-M), so we should take -M as a lower
+    # bound on the best feasible objective
+    return -get_infeasible_cost(
+        X=convex_weights @ X,
+        model=model,
+        objective=objective,
+        posterior_transform=posterior_transform,
+    )
+
+
 def get_infeasible_cost(
     X: Tensor,
     model: Model,
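A quick, self-contained usage sketch of the new utility. Deterministic values are used so the feasible branch is guaranteed to be taken; the tensors and constraint here are illustrative, not from the diff:

```python
import torch
from botorch.acquisition.utils import compute_best_feasible_objective

# Three candidate points with one outcome each (`q x m` = 3 x 1).
samples = torch.tensor([[0.1], [0.7], [1.3]])
obj = samples.squeeze(-1)  # `q`-dim objective values

# Constraint: feasible iff the outcome is below 1.0 (points 0 and 1 qualify).
best_f = compute_best_feasible_objective(
    samples=samples,
    obj=obj,
    constraints=[lambda Y: Y[..., 0] - 1.0],
)
print(best_f)  # tensor([0.7000]): the best objective among feasible points
```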

botorch/utils/objective.py

Lines changed: 26 additions & 3 deletions
@@ -95,14 +95,37 @@ def apply_constraints_nonnegative_soft(
     return obj.clamp_min(0).mul(w)  # Enforce non-negativity of obj, apply constraints.


+def compute_feasibility_indicator(
+    constraints: Optional[List[Callable[[Tensor], Tensor]]],
+    samples: Tensor,
+) -> Tensor:
+    r"""Computes the feasibility of a list of constraints given posterior samples.
+
+    Args:
+        constraints: A list of callables, each mapping a `batch_shape x q x m`-dim
+            Tensor to a `batch_shape x q`-dim Tensor, where negative values imply
+            feasibility.
+        samples: A `batch_shape x q x m`-dim Tensor of posterior samples.
+
+    Returns:
+        A `batch_shape x q`-dim tensor of Boolean feasibility values.
+    """
+    ind = torch.ones(samples.shape[:-1], dtype=torch.bool, device=samples.device)
+    if constraints is not None:
+        for constraint in constraints:
+            ind = ind.logical_and(constraint(samples) < 0)
+    return ind
+
+
 def compute_smoothed_constraint_indicator(
     constraints: List[Callable[[Tensor], Tensor]],
     samples: Tensor,
     eta: Union[Tensor, float],
 ) -> Tensor:
-    r"""Computes the feasibility indicator of a list of constraints given posterior
-    samples, using a sigmoid to smoothly approximate the feasibility indicator
-    of each individual constraint to ensure differentiability and high gradient signal.
+    r"""Computes the smoothed feasibility indicator of a list of constraints.
+
+    Given posterior samples, it uses a sigmoid to smoothly approximate the
+    feasibility indicator of each individual constraint to ensure differentiability
+    and high gradient signal.

     Args:
         constraints: A list of callables, each mapping a Tensor of size `b x q x m`
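The hard (Boolean) indicator is easy to exercise directly. A small sketch with made-up values, showing the negative-means-feasible convention:

```python
import torch
from botorch.utils.objective import compute_feasibility_indicator

# One batch of two candidates with a single outcome (`batch x q x m` = 1 x 2 x 1).
samples = torch.tensor([[[-0.2], [0.3]]])
ind = compute_feasibility_indicator(
    constraints=[lambda Y: Y[..., 0]],  # feasible iff the outcome is negative
    samples=samples,
)
print(ind)  # tensor([[ True, False]])
```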

test/acquisition/test_utils.py

Lines changed: 78 additions & 1 deletion
@@ -19,6 +19,7 @@
     ScalarizedPosteriorTransform,
 )
 from botorch.acquisition.utils import (
+    compute_best_feasible_objective,
     expand_trace_observations,
     get_acquisition_function,
     get_infeasible_cost,
@@ -575,7 +576,83 @@ def test_GetUnknownAcquisitionFunction(self):
     )


-class TestGetInfeasibleCost(BotorchTestCase):
+class TestConstraintUtils(BotorchTestCase):
+    def test_compute_best_feasible_objective(self):
+        for dtype in (torch.float, torch.double):
+            with self.subTest(dtype=dtype):
+                tkwargs = {"dtype": dtype, "device": self.device}
+                n = 5
+                X = torch.arange(n, **tkwargs).view(-1, 1)
+                means = torch.arange(n, **tkwargs).view(-1, 1)
+                samples = means
+                variances = torch.tensor(
+                    [0.09, 0.25, 0.36, 0.25, 0.09], **tkwargs
+                ).view(-1, 1)
+                mm = MockModel(
+                    MockPosterior(mean=means, variance=variances, samples=samples)
+                )
+
+                # testing all feasible points
+                obj = means.squeeze(-1)
+                constraints = [lambda samples: -torch.ones_like(samples[..., 0])]
+                best_f = compute_best_feasible_objective(
+                    samples=means, obj=obj, constraints=constraints
+                )
+                self.assertAllClose(best_f, obj.amax(dim=-1, keepdim=True))
+
+                # testing with some infeasible points
+                con_cutoff = 3.0
+                best_f = compute_best_feasible_objective(
+                    samples=means,
+                    obj=obj,
+                    constraints=[
+                        lambda samples: samples[..., 0] - (con_cutoff + 1 / 2)
+                    ],
+                )
+                # only the first four points (0, 1, 2, 3) are feasible
+                self.assertAllClose(best_f, torch.tensor([con_cutoff], **tkwargs))
+
+                # testing with no feasible points and infeasible obj
+                infeasible_obj = torch.tensor(torch.pi, **tkwargs)
+                best_f = compute_best_feasible_objective(
+                    samples=means,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    infeasible_obj=infeasible_obj,
+                )
+                self.assertAllClose(best_f, infeasible_obj.unsqueeze(0))
+
+                # testing with no feasible points and no infeasible obj
+                def objective(Y, X):
+                    return Y.squeeze(-1) - 5.0
+
+                best_f = compute_best_feasible_objective(
+                    samples=means,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    model=mm,
+                    X_baseline=X,
+                    objective=objective,
+                )
+                self.assertAllClose(
+                    best_f, -get_infeasible_cost(X=X, model=mm, objective=objective)
+                )
+
+                with self.assertRaisesRegex(ValueError, "Must specify `model`"):
+                    best_f = compute_best_feasible_objective(
+                        samples=means,
+                        obj=obj,
+                        constraints=[lambda X: torch.ones_like(X[..., 0])],
+                        X_baseline=X,
+                    )
+                with self.assertRaisesRegex(ValueError, "Must specify `X_baseline`"):
+                    best_f = compute_best_feasible_objective(
+                        samples=means,
+                        obj=obj,
+                        constraints=[lambda X: torch.ones_like(X[..., 0])],
+                        model=mm,
+                    )
+
     def test_get_infeasible_cost(self):
         for dtype in (torch.float, torch.double):
             tkwargs = {"dtype": dtype, "device": self.device}
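The tests above exercise the no-feasible-points fallback through a mock model. A hedged sketch of the same path outside the test harness, using a real `SingleTaskGP` in place of the mock (an illustrative choice, assuming `get_infeasible_cost` accepts `objective=None` as in this BoTorch version):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.utils import compute_best_feasible_objective

train_X = torch.rand(6, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

samples = train_Y.unsqueeze(0)  # `1 x q x m`
obj = samples.squeeze(-1)       # `1 x q`
best_f = compute_best_feasible_objective(
    samples=samples,
    obj=obj,
    constraints=[lambda Y: torch.ones_like(Y[..., 0])],  # nothing is feasible
    model=model,
    X_baseline=train_X,
)
# With no feasible points and no `infeasible_obj`, this falls back to
# -get_infeasible_cost(...) evaluated at convex combinations of X_baseline.
print(best_f.shape)  # torch.Size([1, 1])
```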
