Commit b6f8b85

reset: squash into single commit
1 parent 5d37606

3 files changed (+296, -62 lines)

botorch/optim/optimize.py

Lines changed: 165 additions & 39 deletions
@@ -109,12 +109,41 @@ def __post_init__(self) -> None:
                     "3-dimensional. Its shape is "
                     f"{batch_initial_conditions_shape}."
                 )
+
             if batch_initial_conditions_shape[-1] != d:
                 raise ValueError(
                     f"batch_initial_conditions.shape[-1] must be {d}. The "
                     f"shape is {batch_initial_conditions_shape}."
                 )
 
+            if len(batch_initial_conditions_shape) == 2:
+                warnings.warn(
+                    "If using a 2-dim `batch_initial_conditions` botorch will "
+                    "default to old behavior of ignoring `num_restarts` and just "
+                    "use the given `batch_initial_conditions` by setting "
+                    "`raw_samples` to None.",
+                    RuntimeWarning,
+                )
+                # Use object.__setattr__ to bypass immutability and set a value
+                object.__setattr__(self, "raw_samples", None)
+
+            if (
+                len(batch_initial_conditions_shape) == 3
+                and batch_initial_conditions_shape[0] < self.num_restarts
+                and batch_initial_conditions_shape[-2] != self.q
+            ):
+                warnings.warn(
+                    "If using a 3-dim `batch_initial_conditions` where the "
+                    "first dimension is less than `num_restarts` and the second "
+                    "dimension is not equal to `q`, botorch will default to "
+                    "old behavior of ignoring `num_restarts` and just use the "
+                    "given `batch_initial_conditions` by setting `raw_samples` "
+                    "to None.",
+                    RuntimeWarning,
+                )
+                # Use object.__setattr__ to bypass immutability and set a value
+                object.__setattr__(self, "raw_samples", None)
+
         elif self.ic_generator is None:
             if self.nonlinear_inequality_constraints is not None:
                 raise RuntimeError(
@@ -126,6 +155,7 @@ def __post_init__(self) -> None:
                     "Must specify `raw_samples` when "
                     "`batch_initial_conditions` is None`."
                 )
+
             if self.fixed_features is not None and any(
                 (k < 0 for k in self.fixed_features)
             ):
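
The two warnings added above encode the shape contract for `batch_initial_conditions`: a 3-dim `n x q x d` tensor with `n < num_restarts` keeps `raw_samples` so the missing restarts can be generated later, while a 2-dim `n x d` tensor (or a 3-dim one whose second dimension differs from `q`) falls back to the old behavior by clearing `raw_samples`. A minimal sketch of the shapes involved, with made-up sizes purely for illustration:

import torch

d, q, num_restarts = 2, 1, 5

# 3-dim `n x q x d` with n < num_restarts: raw_samples is kept and the
# remaining restarts are generated and concatenated downstream.
ics_partial = torch.rand(3, q, d)

# 2-dim `n x d`: triggers the RuntimeWarning above and sets raw_samples=None,
# i.e. only the provided points are used (the old behavior).
ics_legacy = torch.rand(4, d)
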
@@ -248,25 +278,54 @@ def _optimize_acqf_sequential_q(
             if base_X_pending is not None
             else candidates
         )
-        logger.info(f"Generated sequential candidate {i+1} of {opt_inputs.q}")
+        logger.info(f"Generated sequential candidate {i + 1} of {opt_inputs.q}")
     opt_inputs.acq_function.set_X_pending(base_X_pending)
     return candidates, torch.stack(acq_value_list)
 
 
+def _combine_initial_conditions(
+    provided_initial_conditions: Tensor | None = None,
+    generated_initial_conditions: Tensor | None = None,
+    dim=0,
+) -> Tensor:
+    if (
+        provided_initial_conditions is not None
+        and generated_initial_conditions is not None
+    ):
+        return torch.cat(
+            [provided_initial_conditions, generated_initial_conditions], dim=dim
+        )
+    elif provided_initial_conditions is not None:
+        return provided_initial_conditions
+    elif generated_initial_conditions is not None:
+        return generated_initial_conditions
+    else:
+        raise ValueError(
+            "Either `batch_initial_conditions` or `raw_samples` must be set."
+        )
+
+
 def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
     options = opt_inputs.options or {}
 
-    initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
+    required_num_restarts = opt_inputs.num_restarts
+    provided_initial_conditions = opt_inputs.batch_initial_conditions
+    generated_initial_conditions = None
 
-    if initial_conditions_provided:
-        batch_initial_conditions = opt_inputs.batch_initial_conditions
-    else:
-        # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
-        batch_initial_conditions = opt_inputs.get_ic_generator()(
+    if (
+        provided_initial_conditions is not None
+        and len(provided_initial_conditions.shape) == 3
+    ):
+        required_num_restarts -= provided_initial_conditions.shape[0]
+
+    if opt_inputs.raw_samples is not None and required_num_restarts > 0:
+        # pyre-ignore[28]: Unexpected keyword argument `acq_function`
+        # to anonymous call.
+        generated_initial_conditions = opt_inputs.get_ic_generator()(
             acq_function=opt_inputs.acq_function,
             bounds=opt_inputs.bounds,
             q=opt_inputs.q,
-            num_restarts=opt_inputs.num_restarts,
+            num_restarts=required_num_restarts,
             raw_samples=opt_inputs.raw_samples,
             fixed_features=opt_inputs.fixed_features,
             options=options,
@@ -275,6 +334,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
             **opt_inputs.ic_gen_kwargs,
         )
 
+    batch_initial_conditions = _combine_initial_conditions(
+        provided_initial_conditions=provided_initial_conditions,
+        generated_initial_conditions=generated_initial_conditions,
+    )
+
     batch_limit: int = options.get(
         "batch_limit",
         (
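
With `_combine_initial_conditions` in place, `_optimize_acqf_batch` tops up a partial set of starting points instead of either fully trusting or fully ignoring it. A hedged usage sketch (the acquisition function and problem sizes are placeholders, not part of this commit):

import torch
from botorch.optim import optimize_acqf

# Suppose num_restarts=8 but only 3 hand-picked starting points exist.
# With a 3-dim (3 x q x d) tensor and raw_samples set, the other 5 restarts
# are drawn by the ic_generator and concatenated along dim=0.
hand_picked = torch.rand(3, 1, 2)  # n=3, q=1, d=2; placeholder values

candidates, acq_value = optimize_acqf(
    acq_function=acq_func,  # assumed to be defined elsewhere
    bounds=torch.stack([torch.zeros(2), torch.ones(2)]),
    q=1,
    num_restarts=8,
    raw_samples=64,
    batch_initial_conditions=hand_picked,
)
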
@@ -325,7 +389,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
             opt_warnings += ws
             batch_candidates_list.append(batch_candidates_curr)
             batch_acq_values_list.append(batch_acq_values_curr)
-            logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
+            logger.info(f"Generated candidate batch {i + 1} of {len(batched_ics)}.")
 
         batch_candidates = torch.cat(batch_candidates_list)
         has_scalars = batch_acq_values_list[0].ndim == 0
@@ -344,23 +408,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
         first_warn_msg = (
             "Optimization failed in `gen_candidates_scipy` with the following "
             f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
-            "`batch_initial_conditions`, optimization will not be retried with "
-            "new initial conditions and will proceed with the current solution."
-            " Suggested remediation: Try again with different "
-            "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
-            if initial_conditions_provided
+            "`batch_initial_conditions` larger than required `num_restarts`, "
+            "optimization will not be retried with new initial conditions and "
+            "will proceed with the current solution. Suggested remediation: "
+            "Try again with different `batch_initial_conditions`, don't provide "
+            "`batch_initial_conditions`, or increase `num_restarts`."
+            if batch_initial_conditions is not None and required_num_restarts <= 0
             else "Optimization failed in `gen_candidates_scipy` with the following "
             f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
             "set of initial conditions."
         )
         warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
 
-        if not initial_conditions_provided:
-            batch_initial_conditions = opt_inputs.get_ic_generator()(
+        if opt_inputs.raw_samples is not None and required_num_restarts > 0:
+            generated_initial_conditions = opt_inputs.get_ic_generator()(
                 acq_function=opt_inputs.acq_function,
                 bounds=opt_inputs.bounds,
                 q=opt_inputs.q,
-                num_restarts=opt_inputs.num_restarts,
+                num_restarts=required_num_restarts,
                 raw_samples=opt_inputs.raw_samples,
                 fixed_features=opt_inputs.fixed_features,
                 options=options,
@@ -369,6 +434,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
                 **opt_inputs.ic_gen_kwargs,
             )
 
+            batch_initial_conditions = _combine_initial_conditions(
+                provided_initial_conditions=provided_initial_conditions,
+                generated_initial_conditions=generated_initial_conditions,
+            )
+
         batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
 
         optimization_warning_raised = any(
@@ -1177,7 +1247,7 @@ def _gen_batch_initial_conditions_local_search(
     inequality_constraints: list[tuple[Tensor, Tensor, float]],
     min_points: int,
     max_tries: int = 100,
-):
+) -> Tensor:
     """Generate initial conditions for local search."""
     device = discrete_choices[0].device
     dtype = discrete_choices[0].dtype
@@ -1197,6 +1267,58 @@ def _gen_batch_initial_conditions_local_search(
     raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")
 
 
+def _gen_starting_points_local_search(
+    discrete_choices: list[Tensor],
+    raw_samples: int,
+    batch_initial_conditions: Tensor,
+    X_avoid: Tensor,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]],
+    min_points: int,
+    acq_function: AcquisitionFunction,
+    max_batch_size: int = 2048,
+    max_tries: int = 100,
+) -> Tensor:
+    required_min_points = min_points
+    provided_X0 = None
+    generated_X0 = None
+
+    if batch_initial_conditions is not None:
+        provided_X0 = _filter_invalid(
+            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
+        )
+        provided_X0 = _filter_infeasible(
+            X=provided_X0, inequality_constraints=inequality_constraints
+        ).unsqueeze(1)
+        required_min_points -= batch_initial_conditions.shape[0]
+
+    if required_min_points > 0:
+        generated_X0 = _gen_batch_initial_conditions_local_search(
+            discrete_choices=discrete_choices,
+            raw_samples=raw_samples,
+            X_avoid=X_avoid,
+            inequality_constraints=inequality_constraints,
+            min_points=min_points,
+            max_tries=max_tries,
+        )
+
+        # pick the best starting points
+        with torch.no_grad():
+            acqvals_init = _split_batch_eval_acqf(
+                acq_function=acq_function,
+                X=generated_X0.unsqueeze(1),
+                max_batch_size=max_batch_size,
+            ).unsqueeze(-1)
+
+        generated_X0 = generated_X0[
+            acqvals_init.topk(k=min_points, largest=True, dim=0).indices
+        ]
+
+    return _combine_initial_conditions(
+        provided_initial_conditions=provided_X0 if provided_X0 is not None else None,
+        generated_initial_conditions=generated_X0 if generated_X0 is not None else None,
+    )
+
+
 def optimize_acqf_discrete_local_search(
     acq_function: AcquisitionFunction,
     discrete_choices: list[Tensor],
@@ -1207,6 +1329,7 @@ def optimize_acqf_discrete_local_search(
     X_avoid: Tensor | None = None,
     batch_initial_conditions: Tensor | None = None,
     max_batch_size: int = 2048,
+    max_tries: int = 100,
     unique: bool = True,
 ) -> tuple[Tensor, Tensor]:
     r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1361,8 @@ def optimize_acqf_discrete_local_search(
         max_batch_size: The maximum number of choices to evaluate in batch.
            A large limit can cause excessive memory usage if the model has
            a large training set.
+        max_tries: Maximum number of iterations to try when generating initial
+            conditions.
        unique: If True return unique choices, o/w choices may be repeated
            (only relevant if `q > 1`).
 
@@ -1247,6 +1372,16 @@ def optimize_acqf_discrete_local_search(
         - a `q x d`-dim tensor of generated candidates.
         - an associated acquisition value.
     """
+    if batch_initial_conditions is not None:
+        if not (
+            len(batch_initial_conditions.shape) == 3
+            and batch_initial_conditions.shape[-2] == 1
+        ):
+            raise ValueError(
+                "batch_initial_conditions must have shape `n x 1 x d` if "
+                f"given (received {batch_initial_conditions})."
+            )
+
     candidate_list = []
     base_X_pending = acq_function.X_pending if q > 1 else None
     base_X_avoid = X_avoid
@@ -1259,27 +1394,18 @@ def optimize_acqf_discrete_local_search(
     inequality_constraints = inequality_constraints or []
     for i in range(q):
         # generate some starting points
-        if i == 0 and batch_initial_conditions is not None:
-            X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
-            X0 = _filter_infeasible(
-                X=X0, inequality_constraints=inequality_constraints
-            ).unsqueeze(1)
-        else:
-            X_init = _gen_batch_initial_conditions_local_search(
-                discrete_choices=discrete_choices,
-                raw_samples=raw_samples,
-                X_avoid=X_avoid,
-                inequality_constraints=inequality_constraints,
-                min_points=num_restarts,
-            )
-            # pick the best starting points
-            with torch.no_grad():
-                acqvals_init = _split_batch_eval_acqf(
-                    acq_function=acq_function,
-                    X=X_init.unsqueeze(1),
-                    max_batch_size=max_batch_size,
-                ).unsqueeze(-1)
-            X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
+        X0 = _gen_starting_points_local_search(
+            discrete_choices=discrete_choices,
+            raw_samples=raw_samples,
+            batch_initial_conditions=batch_initial_conditions,
+            X_avoid=X_avoid,
+            inequality_constraints=inequality_constraints,
+            min_points=num_restarts,
+            acq_function=acq_function,
+            max_batch_size=max_batch_size,
+            max_tries=max_tries,
+        )
+        batch_initial_conditions = None
 
         # optimize from the best starting points
         best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)
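
The local-search refactor above moves starting-point generation into `_gen_starting_points_local_search`, exposes `max_tries`, and requires any provided `batch_initial_conditions` to be `n x 1 x d`; if fewer than `num_restarts` points are provided, the rest are generated from `raw_samples`. A hedged sketch of a conforming call (the acquisition function and choice values are placeholders, not from this commit):

import torch
from botorch.optim.optimize import optimize_acqf_discrete_local_search

discrete_choices = [torch.tensor([0.0, 0.5, 1.0]) for _ in range(3)]  # 3 dims
seed_points = torch.tensor([[[0.0, 0.5, 1.0]]])  # n x 1 x d = 1 x 1 x 3

candidates, acq_value = optimize_acqf_discrete_local_search(
    acq_function=acq_func,  # assumed to be defined elsewhere
    discrete_choices=discrete_choices,
    q=2,
    num_restarts=4,  # the single provided point is topped up with generated ones
    raw_samples=128,
    batch_initial_conditions=seed_points,
    max_tries=100,
)
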

botorch/optim/optimize_homotopy.py

Lines changed: 7 additions & 1 deletion
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
     """
     shared_optimize_acqf_kwargs = {
         "num_restarts": num_restarts,
-        "raw_samples": raw_samples,
         "inequality_constraints": inequality_constraints,
         "equality_constraints": equality_constraints,
         "nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -178,6 +177,7 @@ def optimize_acqf_homotopy(
 
     for _ in range(q):
         candidates = batch_initial_conditions
+        q_raw_samples = raw_samples
         homotopy.restart()
 
         while not homotopy.should_stop:
@@ -187,10 +187,15 @@ def optimize_acqf_homotopy(
                 q=1,
                 options=options,
                 batch_initial_conditions=candidates,
+                raw_samples=q_raw_samples,
                 **shared_optimize_acqf_kwargs,
             )
             homotopy.step()
 
+            # Set raw_samples to None such that pruned restarts are not repopulated
+            # at each step in the homotopy.
+            q_raw_samples = None
+
             # Prune candidates
             candidates = prune_candidates(
                 candidates=candidates.squeeze(1),
@@ -204,6 +209,7 @@ def optimize_acqf_homotopy(
         bounds=bounds,
         q=1,
         options=final_options,
+        raw_samples=q_raw_samples,
         batch_initial_conditions=candidates,
         **shared_optimize_acqf_kwargs,
     )
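
The homotopy change above draws fresh restarts only on the first step of each pass: `q_raw_samples` starts as `raw_samples` and is set to `None` after the first `optimize_acqf` call, so candidates dropped by `prune_candidates` are not repopulated at later homotopy steps. A simplified control-flow sketch of the loop after this change (names refer to the function's arguments and locals, not a standalone snippet):

for _ in range(q):
    candidates = batch_initial_conditions
    q_raw_samples = raw_samples  # fresh restarts only on the first step
    homotopy.restart()
    while not homotopy.should_stop:
        candidates, acq_values = optimize_acqf(
            acq_function=acq_function,
            bounds=bounds,
            q=1,
            options=options,
            batch_initial_conditions=candidates,
            raw_samples=q_raw_samples,
            **shared_optimize_acqf_kwargs,
        )
        homotopy.step()
        q_raw_samples = None  # keep pruned restarts from being repopulated
        # prune, then continue the homotopy with the surviving candidates
        candidates = prune_candidates(
            candidates=candidates.squeeze(1),
            acq_values=acq_values,
            prune_tolerance=prune_tolerance,
        ).unsqueeze(1)
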
