Skip to content

Commit 10e7209

Browse files
committed
fea: allow both batch_initial_conditions and random sampling together.
1 parent 24f659c commit 10e7209

File tree

3 files changed

+134
-24
lines changed

3 files changed

+134
-24
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,7 @@ def optimize_objective(
17791779
bounds=free_feature_bounds,
17801780
q=q,
17811781
num_restarts=optimizer_options.get("num_restarts", 60),
1782-
raw_samples=optimizer_options.get("raw_samples", 1024),
1782+
raw_samples=optimizer_options.get("raw_samples", 1024), # NOTE potential behaviour change
17831783
options={
17841784
"batch_limit": optimizer_options.get("batch_limit", 8),
17851785
"maxiter": optimizer_options.get("maxiter", 200),

botorch/optim/optimize.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ def __post_init__(self) -> None:
114114
f"shape is {batch_initial_conditions_shape}."
115115
)
116116

117+
if (
118+
self.raw_samples is not None
119+
and (self.raw_samples - batch_initial_conditions_shape[-2]) > 0
120+
and len(batch_initial_conditions_shape) == 3
121+
and self.num_restarts is not None
122+
and self.num_restarts != batch_initial_conditions_shape[0]
123+
):
124+
raise ValueError(
125+
"If using `batch_initial_conditions` together with `raw_samples`, "
126+
"the first repeat dimension of `batch_initial_conditions` must "
127+
"match `num_restarts`."
128+
)
129+
117130
elif self.ic_generator is None:
118131
if self.nonlinear_inequality_constraints is not None:
119132
raise RuntimeError(
@@ -253,22 +266,44 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
253266

254267
initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
255268

269+
required_raw_samples = opt_inputs.raw_samples
256270
if initial_conditions_provided:
257-
batch_initial_conditions = opt_inputs.batch_initial_conditions
271+
provided_initial_conditions = opt_inputs.batch_initial_conditions
272+
if opt_inputs.raw_samples is not None:
273+
required_raw_samples -= provided_initial_conditions.shape[-2]
258274
else:
275+
provided_initial_conditions = None
276+
277+
if required_raw_samples is not None and required_raw_samples > 0:
259278
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
260-
batch_initial_conditions = opt_inputs.get_ic_generator()(
279+
generated_initial_conditions = opt_inputs.get_ic_generator()(
261280
acq_function=opt_inputs.acq_function,
262281
bounds=opt_inputs.bounds,
263282
q=opt_inputs.q,
264283
num_restarts=opt_inputs.num_restarts,
265-
raw_samples=opt_inputs.raw_samples,
284+
raw_samples=required_raw_samples,
266285
fixed_features=opt_inputs.fixed_features,
267286
options=options,
268287
inequality_constraints=opt_inputs.inequality_constraints,
269288
equality_constraints=opt_inputs.equality_constraints,
270289
**opt_inputs.ic_gen_kwargs,
271290
)
291+
else:
292+
generated_initial_conditions = None
293+
294+
if provided_initial_conditions is not None and generated_initial_conditions is not None:
295+
provided_initial_conditions = provided_initial_conditions.repeat(
296+
opt_inputs.num_restarts, *([1] * (provided_initial_conditions.dim()-1))
297+
)
298+
batch_initial_conditions = torch.cat(
299+
[provided_initial_conditions, generated_initial_conditions], dim=-2
300+
) # should this be shuffled?
301+
elif provided_initial_conditions is not None:
302+
batch_initial_conditions = provided_initial_conditions
303+
elif generated_initial_conditions is not None:
304+
batch_initial_conditions = generated_initial_conditions
305+
else:
306+
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")
272307

273308
batch_limit: int = options.get(
274309
"batch_limit",
@@ -339,31 +374,39 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
339374
first_warn_msg = (
340375
"Optimization failed in `gen_candidates_scipy` with the following "
341376
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
342-
"`batch_initial_conditions`, optimization will not be retried with "
343-
"new initial conditions and will proceed with the current solution."
344-
" Suggested remediation: Try again with different "
345-
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
346-
if initial_conditions_provided
377+
"`batch_initial_conditions`>`raw_samples`, optimization will not "
378+
"be retried with new initial conditions and will proceed with the "
379+
"current solution. Suggested remediation: Try again with different "
380+
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
381+
"or increase `raw_samples`.`"
382+
if required_raw_samples is not None and required_raw_samples > 0
347383
else "Optimization failed in `gen_candidates_scipy` with the following "
348384
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
349385
"set of initial conditions."
350386
)
351387
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
352388

353-
if not initial_conditions_provided:
354-
batch_initial_conditions = opt_inputs.get_ic_generator()(
389+
if required_raw_samples is not None and required_raw_samples > 0:
390+
generated_initial_conditions = opt_inputs.get_ic_generator()(
355391
acq_function=opt_inputs.acq_function,
356392
bounds=opt_inputs.bounds,
357393
q=opt_inputs.q,
358394
num_restarts=opt_inputs.num_restarts,
359-
raw_samples=opt_inputs.raw_samples,
395+
raw_samples=required_raw_samples,
360396
fixed_features=opt_inputs.fixed_features,
361397
options=options,
362398
inequality_constraints=opt_inputs.inequality_constraints,
363399
equality_constraints=opt_inputs.equality_constraints,
364400
**opt_inputs.ic_gen_kwargs,
365401
)
366402

403+
if provided_initial_conditions is not None:
404+
batch_initial_conditions = torch.cat(
405+
[provided_initial_conditions, generated_initial_conditions], dim=-2
406+
) # should this be shuffled?
407+
else:
408+
batch_initial_conditions = generated_initial_conditions
409+
367410
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
368411

369412
optimization_warning_raised = any(
@@ -1199,11 +1242,46 @@ def optimize_acqf_discrete_local_search(
11991242
inequality_constraints = inequality_constraints or []
12001243
for i in range(q):
12011244
# generate some starting points
1202-
if i == 0 and batch_initial_conditions is not None:
1203-
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
1204-
X0 = _filter_infeasible(
1205-
X=X0, inequality_constraints=inequality_constraints
1206-
).unsqueeze(1)
1245+
if i == 0:
1246+
1247+
if batch_initial_conditions is not None:
1248+
provided_X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
1249+
provided_X0 = _filter_infeasible(
1250+
X=provided_X0, inequality_constraints=inequality_constraints
1251+
).unsqueeze(1)
1252+
if raw_samples is not None:
1253+
required_raw_samples = raw_samples - batch_initial_conditions.shape[-2]
1254+
else:
1255+
required_raw_samples = raw_samples
1256+
provided_X0 = None
1257+
1258+
if required_raw_samples > 0:
1259+
X_init = _gen_batch_initial_conditions_local_search(
1260+
discrete_choices=discrete_choices,
1261+
raw_samples=required_raw_samples,
1262+
X_avoid=X_avoid,
1263+
inequality_constraints=inequality_constraints,
1264+
min_points=num_restarts,
1265+
)
1266+
# pick the best starting points
1267+
with torch.no_grad():
1268+
acqvals_init = _split_batch_eval_acqf(
1269+
acq_function=acq_function,
1270+
X=X_init.unsqueeze(1),
1271+
max_batch_size=max_batch_size,
1272+
).unsqueeze(-1)
1273+
generated_X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
1274+
1275+
if provided_X0 is not None and generated_X0 is not None:
1276+
provided_X0 = provided_X0.repeat(num_restarts, *([1] * (provided_X0.ndim - 1)))
1277+
X0 = torch.cat([provided_X0, generated_X0], dim=-2)
1278+
elif provided_X0 is not None:
1279+
X0 = provided_X0
1280+
elif generated_X0 is not None:
1281+
X0 = generated_X0
1282+
else:
1283+
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")
1284+
12071285
else:
12081286
X_init = _gen_batch_initial_conditions_local_search(
12091287
discrete_choices=discrete_choices,

test/optim/test_optimize.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,26 @@ def test_optimize_acqf_joint(
167167
cnt += 1
168168
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
169169

170-
# test generation with provided initial conditions
170+
# test generation with provided initial conditions less than raw_samples
171+
candidates, acq_vals = optimize_acqf(
172+
acq_function=mock_acq_function,
173+
bounds=bounds,
174+
q=q,
175+
num_restarts=num_restarts,
176+
raw_samples=3,
177+
options=options,
178+
return_best_only=False,
179+
batch_initial_conditions=torch.zeros(
180+
num_restarts, q, 3, device=self.device, dtype=dtype
181+
),
182+
gen_candidates=mock_gen_candidates,
183+
)
184+
self.assertTrue(torch.equal(candidates, mock_candidates))
185+
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
186+
cnt += 1
187+
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
188+
189+
# test generation with provided initial conditions greater than raw_samples
171190
candidates, acq_vals = optimize_acqf(
172191
acq_function=mock_acq_function,
173192
bounds=bounds,
@@ -543,7 +562,15 @@ def test_optimize_acqf_batch_limit(self) -> None:
543562
gen_candidates=gen_candidates,
544563
batch_initial_conditions=ics,
545564
)
546-
expected_shape = (num_restarts,) if ics is None else (ics.shape[0],)
565+
expected_shape = (
566+
(num_restarts,)
567+
if ics is None
568+
else (
569+
(ics.shape[0],)
570+
if ics.shape[0] > raw_samples
571+
else (ics.shape[0]*num_restarts,)
572+
)
573+
)
547574
self.assertEqual(acq_value_list.shape, expected_shape)
548575

549576
def test_optimize_acqf_runs_given_batch_initial_conditions(self):
@@ -635,11 +662,12 @@ def test_optimize_acqf_warns_on_opt_failure(self):
635662
"Optimization failed in `gen_candidates_scipy` with the following "
636663
"warning(s):\n[OptimizationWarning('Optimization failed within "
637664
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
638-
"_IN_LNSRCH.')]\nBecause you specified `batch_initial_conditions`, "
639-
"optimization will not be retried with new initial conditions and will "
640-
"proceed with the current solution. Suggested remediation: Try again with "
641-
"different `batch_initial_conditions`, or don't provide "
642-
"`batch_initial_conditions.`"
665+
"_IN_LNSRCH.')]\nBecause you specified "
666+
"`batch_initial_conditions`>`raw_samples`, optimization will not "
667+
"be retried with new initial conditions and will proceed with the "
668+
"current solution. Suggested remediation: Try again with different "
669+
"`batch_initial_conditions`, don't provide `batch_initial_conditions, "
670+
"or increase `raw_samples`.`"
643671
)
644672
expected_warning_raised = any(
645673
issubclass(w.category, RuntimeWarning) and message in str(w.message)
@@ -1841,3 +1869,7 @@ def my_gen():
18411869
)
18421870
ic_generator = opt_inputs.get_ic_generator()
18431871
self.assertIs(ic_generator, my_gen)
1872+
1873+
if __name__ == "__main__":
1874+
import pytest
1875+
pytest.main([__file__])

0 commit comments

Comments
 (0)