
Commit 2f2b7e2

esantorella authored and facebook-github-bot committed
Fix shape error in optimize_acqf_cyclic (#1648)
Summary:

## Motivation

Fixes #873.

In the past, `optimize_acqf` implicitly required 3d inputs whenever there were equality or inequality constraints and `fixed_features` did not provide the trivial solution, even though it worked with 2d inputs (no b-batches) in other cases. `optimize_acqf_cyclic` passed it 2d inputs, which would not generally work. I initially considered changing `optimize_acqf_cyclic` to pass 3d inputs, but since I found another place where 2d inputs were used, I decided to change `optimize_acqf` so that it works with 2d inputs instead.

This was not caught because the only usage of `optimize_acqf_cyclic` was in a test that mocked `optimize_acqf`, so `optimize_acqf_cyclic` was never actually run end-to-end. I changed the test for `optimize_acqf_cyclic` to be more end-to-end, at the cost of worse testing of some intermediate properties. We could keep both versions, though.

- [x] Better docstring documentation of input shapes
- [x] Add a singleton leading b-dimension where initial conditions are 2d

Pull Request resolved: #1648

Test Plan:

- [x] More end-to-end test of `optimize_acqf_cyclic` that doesn't stub in `optimize_acqf` (see above)
- [x] More input validation, and unit tests for the input validation
- [x] Ran cases that now raise errors without the new error handling, to make sure they were erroring before
- [x] Made `_make_linear_constraints` work with 2d inputs so that `optimize_acqf` also does (previously, `optimize_acqf` only worked in some cases)

Reviewed By: Balandat

Differential Revision: D42875942

Pulled By: esantorella

fbshipit-source-id: e3c650683a6b8d7c9e36fe1f14558db2854bab56
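To make the shape conventions concrete, here is a minimal sketch (not part of the diff) of the two `batch_initial_conditions` layouts `optimize_acqf` now accepts, and the singleton b-dimension the fix adds internally:

```python
import torch

b, q, d = 5, 2, 3
ics_3d = torch.rand(b, q, d)  # explicit b-batch: always worked
ics_2d = torch.rand(q, d)     # no b-batch: previously failed with constraints

# The fix promotes 2d inputs to a singleton leading b-dimension, so both
# layouts are optimized over the same `b x q x d` representation.
assert ics_2d.unsqueeze(0).shape == torch.Size([1, q, d])
```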
1 parent ffcad4a commit 2f2b7e2

5 files changed: +197, −98 lines

botorch/generation/gen.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -56,7 +56,8 @@ def gen_candidates_scipy(
     using `scipy.optimize.minimize` via a numpy converter.
 
     Args:
-        initial_conditions: Starting points for optimization.
+        initial_conditions: Starting points for optimization, with shape
+            (b) x q x d.
         acquisition_function: Acquisition function to be used.
         lower_bounds: Minimum values for each column of initial_conditions.
         upper_bounds: Maximum values for each column of initial_conditions.
@@ -162,7 +163,7 @@ def gen_candidates_scipy(
         X=initial_conditions, lower_bounds=lower_bounds, upper_bounds=upper_bounds
     )
     constraints = make_scipy_linear_constraints(
-        shapeX=clamped_candidates.shape,
+        shapeX=shapeX,
        inequality_constraints=inequality_constraints,
        equality_constraints=equality_constraints,
    )
```
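A hedged usage sketch of the repaired path: calling `gen_candidates_scipy` directly with both 2d and 3d initial conditions. `NegSumSquares` is a hypothetical stand-in for a real acquisition function, used here only so the snippet is self-contained:

```python
import torch
from botorch.generation.gen import gen_candidates_scipy

# Hypothetical stand-in acquisition function: maps a `(b) x q x d` input to a
# `(b)`-dim (or scalar) value; any such callable would do for illustration.
class NegSumSquares(torch.nn.Module):
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return -X.pow(2).sum(dim=(-2, -1))

q, d = 2, 3
for ics in (torch.rand(q, d), torch.rand(1, q, d)):  # 2d and `b x q x d` layouts
    candidates, acq_value = gen_candidates_scipy(
        initial_conditions=ics,
        acquisition_function=NegSumSquares(),
        lower_bounds=-torch.ones(d),
        upper_bounds=torch.ones(d),
    )
```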

botorch/optim/optimize.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -110,7 +110,9 @@ def optimize_acqf(
     Returns:
         A two-element tuple containing
 
-        - a `(num_restarts) x q x d`-dim tensor of generated candidates.
+        - A tensor of generated candidates. The shape is
+            -- `q x d` if `return_best_only` is True (default)
+            -- `num_restarts x q x d` if `return_best_only` is False
         - a tensor of associated acquisition values. If `sequential=False`,
             this is a `(num_restarts)`-dim tensor of joint acquisition values
             (with explicit restart dimension if `return_best_only=False`). If
```
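A sketch of the two documented return shapes, again using a hypothetical stand-in acquisition function (`NegSumSquares` is not part of the repo):

```python
import torch
from botorch.optim import optimize_acqf

# Hypothetical stand-in acquisition function for illustration only.
class NegSumSquares(torch.nn.Module):
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return -X.pow(2).sum(dim=(-2, -1))

bounds = torch.stack([-torch.ones(3), torch.ones(3)])
q, num_restarts = 2, 4

# Default return_best_only=True: candidates have shape q x d.
best, best_val = optimize_acqf(
    acq_function=NegSumSquares(), bounds=bounds, q=q,
    num_restarts=num_restarts, raw_samples=8,
)
assert best.shape == torch.Size([q, 3])

# return_best_only=False: candidates have shape num_restarts x q x d.
all_cands, all_vals = optimize_acqf(
    acq_function=NegSumSquares(), bounds=bounds, q=q,
    num_restarts=num_restarts, raw_samples=8, return_best_only=False,
)
assert all_cands.shape == torch.Size([num_restarts, q, 3])
```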
```diff
@@ -158,6 +160,19 @@ def optimize_acqf(
             "initial conditions for the case of nonlinear inequality constraints."
         )
 
+    d = bounds.shape[1]
+    if initial_conditions_provided:
+        if batch_initial_conditions.ndim not in (2, 3):
+            raise ValueError(
+                "batch_initial_conditions must be 2-dimensional or 3-dimensional. "
+                f"Its shape is {batch_initial_conditions.shape}."
+            )
+        if batch_initial_conditions.shape[-1] != d:
+            raise ValueError(
+                f"batch_initial_conditions.shape[-1] must be {d}. The "
+                f"shape is {batch_initial_conditions.shape}."
+            )
+
     # Sets initial condition generator ic_gen if initial conditions not provided
     if not initial_conditions_provided:
         ic_gen = kwargs.pop("ic_generator", None)
```
```diff
@@ -298,7 +313,7 @@ def _optimize_batch_candidates(
             logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
 
         batch_candidates = torch.cat(batch_candidates_list)
-        batch_acq_values = torch.cat(batch_acq_values_list)
+        batch_acq_values = torch.stack(batch_acq_values_list).flatten()
         return batch_candidates, batch_acq_values, opt_warnings
 
     batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
```
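The `torch.stack(...).flatten()` change presumably handles the case where a batch's acquisition values come back as a 0-dim tensor (as with un-batched 2d inputs), which `torch.cat` rejects. A minimal sketch of the difference:

```python
import torch

vals = [torch.tensor(1.5), torch.tensor(2.5)]  # 0-dim acquisition values

try:
    torch.cat(vals)  # fails: 0-dim tensors cannot be concatenated
except RuntimeError as e:
    print(e)  # e.g. "zero-dimensional tensor ... cannot be concatenated"

print(torch.stack(vals).flatten())  # tensor([1.5000, 2.5000])
```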

botorch/optim/parameter_constraints.py

Lines changed: 28 additions & 3 deletions
```diff
@@ -73,7 +73,7 @@ def make_scipy_linear_constraints(
     r"""Generate scipy constraints from torch representation.
 
     Args:
-        shapeX: The shape of the torch.Tensor to optimize over (i.e. `b x q x d`)
+        shapeX: The shape of the torch.Tensor to optimize over (i.e. `(b) x q x d`)
         inequality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`, where
@@ -219,10 +219,35 @@ def _make_linear_constraints(
         version of the input tensor `X`, returning a scalar.
     - "jac": A callable evaluating the constraint's Jacobian on `x`, a flattened
         version of the input tensor `X`, returning a numpy array.
+
+    >>> shapeX = torch.Size([3, 5, 4])
+    >>> constraints = _make_linear_constraints(
+    ...     indices=torch.tensor([1., 2.]),
+    ...     coefficients=torch.tensor([-0.5, 1.3]),
+    ...     rhs=0.49,
+    ...     shapeX=shapeX,
+    ...     eq=True
+    ... )
+    >>> len(constraints)
+    15
+    >>> constraints[0].keys()
+    dict_keys(['type', 'fun', 'jac'])
+    >>> x = np.arange(60).reshape(shapeX)
+    >>> constraints[0]["fun"](x)
+    1.61  # 1 * -0.5 + 2 * 1.3 - 0.49
+    >>> constraints[0]["jac"](x)
+    [0., -0.5, 1.3, 0., 0., ...]
+    >>> constraints[1]["fun"](x)
+    4.81
     """
-    if len(shapeX) != 3:
-        raise UnsupportedError("`shapeX` must be `b x q x d`")
+    if len(shapeX) not in (2, 3):
+        raise UnsupportedError(
+            f"`shapeX` must be `(b) x q x d` (at least two-dimensional). It is "
+            f"{shapeX}."
+        )
     q, d = shapeX[-2:]
+    if len(shapeX) == 2:
+        shapeX = torch.Size([1, q, d])
     n = shapeX.numel()
     constraints: List[ScipyConstraintDict] = []
     coeffs = _arrayify(coefficients)
```
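Extrapolating from the doctest above: with this change a 2d `shapeX` is promoted to a singleton b-batch, so a `q x d` shape should yield `1 * q` constraints. A sketch (not from the diff), under that assumption:

```python
import torch
from botorch.optim.parameter_constraints import _make_linear_constraints

# Per the doctest above, a `b x q x d` shape yields b * q constraints
# (15 for [3, 5, 4]); a 2d `q x d` shape is now treated as `1 x q x d`.
constraints = _make_linear_constraints(
    indices=torch.tensor([1.0, 2.0]),
    coefficients=torch.tensor([-0.5, 1.3]),
    rhs=0.49,
    shapeX=torch.Size([5, 4]),  # q x d, no b-batch
    eq=True,
)
assert len(constraints) == 5  # 1 * q
```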

test/optim/test_optimize.py

Lines changed: 80 additions & 35 deletions
```diff
@@ -334,24 +334,64 @@ def test_optimize_acqf_sequential_notimplemented(self):
         )
 
     def test_optimize_acqf_runs_given_batch_initial_conditions(self):
-        num_restarts, raw_samples, dim = 1, 1, 1
+        num_restarts, raw_samples, dim = 1, 2, 3
 
         opt_x = 2 / np.pi
-        # start near one (of many) optima
-        initial_conditions = (opt_x * 1.01) * torch.ones(
-            (num_restarts, raw_samples, dim)
-        )
+        # -x[i] * 1 >= -opt_x * 1.01 => x[i] <= opt_x * 1.01
+        inequality_constraints = [
+            (torch.tensor([i]), -torch.tensor([1]), -opt_x * 1.01) for i in range(dim)
+        ] + [
+            # x[i] * 1 >= opt_x * .99
+            (torch.tensor([i]), torch.tensor([1]), opt_x * 0.99)
+            for i in range(dim)
+        ]
+        q = 1
+
+        ic_shapes = [(1, 2, dim), (2, 1, dim), (1, dim)]
+
         torch.manual_seed(0)
-        batch_candidates, acq_value_list = optimize_acqf(
-            acq_function=SinOneOverXAcqusitionFunction(),
-            bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
-            q=1,
-            num_restarts=num_restarts,
-            raw_samples=raw_samples,
-            batch_initial_conditions=initial_conditions,
-        )
-        self.assertAlmostEqual(batch_candidates.item(), opt_x, delta=1e-5)
-        self.assertAlmostEqual(acq_value_list.item(), 1)
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                # start near one (of many) optima
+                initial_conditions = (opt_x * 1.01) * torch.ones(shape)
+                batch_candidates, acq_value_list = optimize_acqf(
+                    acq_function=SinOneOverXAcqusitionFunction(),
+                    bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    batch_initial_conditions=initial_conditions,
+                    inequality_constraints=inequality_constraints,
+                )
+                self.assertAllClose(
+                    batch_candidates,
+                    opt_x * torch.ones_like(batch_candidates),
+                    # must be at least 50% closer to the optimum than it started
+                    atol=0.004,
+                    rtol=0.005,
+                )
+                self.assertAlmostEqual(acq_value_list.item(), 1, places=3)
+
+    def test_optimize_acqf_wrong_ic_shape_inequality_constraints(self) -> None:
+        dim = 3
+        ic_shapes = [(1, 2, dim + 1), (1, 2, dim, 1), (1, dim + 1), (1, 1), (dim,)]
+
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                initial_conditions = torch.ones(shape)
+                expected_error = (
+                    rf"batch_initial_conditions.shape\[-1\] must be {dim}\."
+                    if len(shape) in (2, 3)
+                    else r"batch_initial_conditions must be 2\-dimensional or "
+                )
+                with self.assertRaisesRegex(ValueError, expected_error):
+                    optimize_acqf(
+                        acq_function=MockAcquisitionFunction(),
+                        bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                        q=4,
+                        batch_initial_conditions=initial_conditions,
+                        num_restarts=1,
+                    )
 
     def test_optimize_acqf_warns_on_opt_failure(self):
         """
```
```diff
@@ -808,15 +848,20 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
         tkwargs = {"device": self.device}
         bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
         inequality_constraints = [
-            [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
+            [torch.tensor([2], dtype=int), torch.tensor([4.0]), torch.tensor(5.0)]
         ]
         mock_acq_function = MockAcquisitionFunction()
         for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
-            inequality_constraints[0] = [
-                t.to(**tkwargs) for t in inequality_constraints[0]
+            tkwargs["dtype"] = dtype
+            inequality_constraints = [
+                (
+                    # indices can't be floats or doubles
+                    inequality_constraints[0][0],
+                    inequality_constraints[0][1].to(**tkwargs),
+                    inequality_constraints[0][2].to(**tkwargs),
+                )
             ]
             mock_optimize_acqf.reset_mock()
-            tkwargs["dtype"] = dtype
             bounds = bounds.to(**tkwargs)
             candidate_rvs = []
             acq_val_rvs = []
```
```diff
@@ -855,23 +900,23 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
                 post_processing_func=rounding_func,
                 cyclic_options={"maxiter": num_cycles},
             )
-                # check that X_pending is set correctly in cyclic optimization
-                if q > 1:
-                    x_pending_call_args_list = mock_set_X_pending.call_args_list
-                    idxr = torch.ones(q, dtype=torch.bool, device=self.device)
-                    for i in range(len(x_pending_call_args_list) - 1):
-                        idxr[i] = 0
-                        self.assertTrue(
-                            torch.equal(
-                                x_pending_call_args_list[i][0][0], orig_candidates[idxr]
-                            )
-                        )
-                        idxr[i] = 1
-                        orig_candidates[i] = candidate_rvs[i + 1]
-                    # check reset to base_X_pending
-                    self.assertIsNone(x_pending_call_args_list[-1][0][0])
-                else:
-                    mock_set_X_pending.assert_not_called()
+            # check that X_pending is set correctly in cyclic optimization
+            if q > 1:
+                x_pending_call_args_list = mock_set_X_pending.call_args_list
+                idxr = torch.ones(q, dtype=torch.bool, device=self.device)
+                for i in range(len(x_pending_call_args_list) - 1):
+                    idxr[i] = 0
+                    self.assertTrue(
+                        torch.equal(
+                            x_pending_call_args_list[i][0][0], orig_candidates[idxr]
+                        )
+                    )
+                    idxr[i] = 1
+                    orig_candidates[i] = candidate_rvs[i + 1]
+                # check reset to base_X_pending
+                self.assertIsNone(x_pending_call_args_list[-1][0][0])
+            else:
+                mock_set_X_pending.assert_not_called()
             # check final candidates
             expected_candidates = (
                 torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]
```
