@@ -109,12 +109,41 @@ def __post_init__(self) -> None:
109
109
"3-dimensional. Its shape is "
110
110
f"{ batch_initial_conditions_shape } ."
111
111
)
112
+
112
113
if batch_initial_conditions_shape [- 1 ] != d :
113
114
raise ValueError (
114
115
f"batch_initial_conditions.shape[-1] must be { d } . The "
115
116
f"shape is { batch_initial_conditions_shape } ."
116
117
)
117
118
119
+ if len (batch_initial_conditions_shape ) == 2 :
120
+ warnings .warn (
121
+ "If using a 2-dim `batch_initial_conditions` botorch will "
122
+ "default to old behavior of ignoring `num_restarts` and just "
123
+ "use the given `batch_initial_conditions` by setting "
124
+ "`raw_samples` to None." ,
125
+ RuntimeWarning ,
126
+ )
127
+ # Use object.__setattr__ to bypass immutability and set a value
128
+ object .__setattr__ (self , "raw_samples" , None )
129
+
130
+ if (
131
+ len (batch_initial_conditions_shape ) == 3
132
+ and batch_initial_conditions_shape [0 ] < self .num_restarts
133
+ and batch_initial_conditions_shape [- 2 ] != self .q
134
+ ):
135
+ warnings .warn (
136
+ "If using a 3-dim `batch_initial_conditions` where the "
137
+ "first dimension is less than `num_restarts` and the second "
138
+ "dimension is not equal to `q`, botorch will default to "
139
+ "old behavior of ignoring `num_restarts` and just use the "
140
+ "given `batch_initial_conditions` by setting `raw_samples` "
141
+ "to None." ,
142
+ RuntimeWarning ,
143
+ )
144
+ # Use object.__setattr__ to bypass immutability and set a value
145
+ object .__setattr__ (self , "raw_samples" , None )
146
+
118
147
elif self .ic_generator is None :
119
148
if self .nonlinear_inequality_constraints is not None :
120
149
raise RuntimeError (
@@ -126,6 +155,7 @@ def __post_init__(self) -> None:
126
155
"Must specify `raw_samples` when "
127
156
"`batch_initial_conditions` is None`."
128
157
)
158
+
129
159
if self .fixed_features is not None and any (
130
160
(k < 0 for k in self .fixed_features )
131
161
):
@@ -248,25 +278,54 @@ def _optimize_acqf_sequential_q(
248
278
if base_X_pending is not None
249
279
else candidates
250
280
)
251
- logger .info (f"Generated sequential candidate { i + 1 } of { opt_inputs .q } " )
281
+ logger .info (f"Generated sequential candidate { i + 1 } of { opt_inputs .q } " )
252
282
opt_inputs .acq_function .set_X_pending (base_X_pending )
253
283
return candidates , torch .stack (acq_value_list )
254
284
255
285
286
+ def _combine_initial_conditions (
287
+ provided_initial_conditions : Tensor | None = None ,
288
+ generated_initial_conditions : Tensor | None = None ,
289
+ dim = 0 ,
290
+ ) -> Tensor :
291
+ if (
292
+ provided_initial_conditions is not None
293
+ and generated_initial_conditions is not None
294
+ ):
295
+ return torch .cat (
296
+ [provided_initial_conditions , generated_initial_conditions ], dim = dim
297
+ )
298
+ elif provided_initial_conditions is not None :
299
+ return provided_initial_conditions
300
+ elif generated_initial_conditions is not None :
301
+ return generated_initial_conditions
302
+ else :
303
+ raise ValueError (
304
+ "Either `batch_initial_conditions` or `raw_samples` must be set."
305
+ )
306
+
307
+
256
308
def _optimize_acqf_batch (opt_inputs : OptimizeAcqfInputs ) -> tuple [Tensor , Tensor ]:
257
309
options = opt_inputs .options or {}
258
310
259
- initial_conditions_provided = opt_inputs .batch_initial_conditions is not None
311
+ required_num_restarts = opt_inputs .num_restarts
312
+ provided_initial_conditions = opt_inputs .batch_initial_conditions
313
+ generated_initial_conditions = None
260
314
261
- if initial_conditions_provided :
262
- batch_initial_conditions = opt_inputs .batch_initial_conditions
263
- else :
264
- # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
265
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
315
+ if (
316
+ provided_initial_conditions is not None
317
+ and len (provided_initial_conditions .shape ) == 3
318
+ ):
319
+ required_num_restarts -= provided_initial_conditions .shape [0 ]
320
+
321
+ if opt_inputs .raw_samples is not None and required_num_restarts > 0 :
322
+ # pyre-ignore[28]: Unexpected keyword argument `acq_function`
323
+ # to anonymous call.
324
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
266
325
acq_function = opt_inputs .acq_function ,
267
326
bounds = opt_inputs .bounds ,
268
327
q = opt_inputs .q ,
269
- num_restarts = opt_inputs . num_restarts ,
328
+ num_restarts = required_num_restarts ,
270
329
raw_samples = opt_inputs .raw_samples ,
271
330
fixed_features = opt_inputs .fixed_features ,
272
331
options = options ,
@@ -275,6 +334,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
275
334
** opt_inputs .ic_gen_kwargs ,
276
335
)
277
336
337
+ batch_initial_conditions = _combine_initial_conditions (
338
+ provided_initial_conditions = provided_initial_conditions ,
339
+ generated_initial_conditions = generated_initial_conditions ,
340
+ )
341
+
278
342
batch_limit : int = options .get (
279
343
"batch_limit" ,
280
344
(
@@ -325,7 +389,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
325
389
opt_warnings += ws
326
390
batch_candidates_list .append (batch_candidates_curr )
327
391
batch_acq_values_list .append (batch_acq_values_curr )
328
- logger .info (f"Generated candidate batch { i + 1 } of { len (batched_ics )} ." )
392
+ logger .info (f"Generated candidate batch { i + 1 } of { len (batched_ics )} ." )
329
393
330
394
batch_candidates = torch .cat (batch_candidates_list )
331
395
has_scalars = batch_acq_values_list [0 ].ndim == 0
@@ -344,23 +408,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
344
408
first_warn_msg = (
345
409
"Optimization failed in `gen_candidates_scipy` with the following "
346
410
f"warning(s):\n { [w .message for w in ws ]} \n Because you specified "
347
- "`batch_initial_conditions`, optimization will not be retried with "
348
- "new initial conditions and will proceed with the current solution."
349
- " Suggested remediation: Try again with different "
350
- "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
351
- if initial_conditions_provided
411
+ "`batch_initial_conditions` larger than required `num_restarts`, "
412
+ "optimization will not be retried with new initial conditions and "
413
+ "will proceed with the current solution. Suggested remediation: "
414
+ "Try again with different `batch_initial_conditions`, don't provide "
415
+ "`batch_initial_conditions`, or increase `num_restarts`."
416
+ if batch_initial_conditions is not None and required_num_restarts <= 0
352
417
else "Optimization failed in `gen_candidates_scipy` with the following "
353
418
f"warning(s):\n { [w .message for w in ws ]} \n Trying again with a new "
354
419
"set of initial conditions."
355
420
)
356
421
warnings .warn (first_warn_msg , RuntimeWarning , stacklevel = 2 )
357
422
358
- if not initial_conditions_provided :
359
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
423
+ if opt_inputs . raw_samples is not None and required_num_restarts > 0 :
424
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
360
425
acq_function = opt_inputs .acq_function ,
361
426
bounds = opt_inputs .bounds ,
362
427
q = opt_inputs .q ,
363
- num_restarts = opt_inputs . num_restarts ,
428
+ num_restarts = required_num_restarts ,
364
429
raw_samples = opt_inputs .raw_samples ,
365
430
fixed_features = opt_inputs .fixed_features ,
366
431
options = options ,
@@ -369,6 +434,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
369
434
** opt_inputs .ic_gen_kwargs ,
370
435
)
371
436
437
+ batch_initial_conditions = _combine_initial_conditions (
438
+ provided_initial_conditions = provided_initial_conditions ,
439
+ generated_initial_conditions = generated_initial_conditions ,
440
+ )
441
+
372
442
batch_candidates , batch_acq_values , ws = _optimize_batch_candidates ()
373
443
374
444
optimization_warning_raised = any (
@@ -1177,7 +1247,7 @@ def _gen_batch_initial_conditions_local_search(
1177
1247
inequality_constraints : list [tuple [Tensor , Tensor , float ]],
1178
1248
min_points : int ,
1179
1249
max_tries : int = 100 ,
1180
- ):
1250
+ ) -> Tensor :
1181
1251
"""Generate initial conditions for local search."""
1182
1252
device = discrete_choices [0 ].device
1183
1253
dtype = discrete_choices [0 ].dtype
@@ -1197,6 +1267,58 @@ def _gen_batch_initial_conditions_local_search(
1197
1267
raise RuntimeError (f"Failed to generate at least { min_points } initial conditions" )
1198
1268
1199
1269
1270
def _gen_starting_points_local_search(
    discrete_choices: list[Tensor],
    raw_samples: int,
    batch_initial_conditions: Tensor | None,
    X_avoid: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]],
    min_points: int,
    acq_function: AcquisitionFunction,
    max_batch_size: int = 2048,
    max_tries: int = 100,
) -> Tensor:
    """Generate starting points for local search over a discrete domain.

    Any user-provided `batch_initial_conditions` are filtered for validity
    (against `X_avoid`) and feasibility (against `inequality_constraints`)
    and counted (by their pre-filtering count) toward `min_points`. If they
    do not cover `min_points`, additional points are generated and the best
    `min_points` of them by acquisition value are kept.

    Args:
        discrete_choices: A list of tensors of discrete values, one per
            dimension of the search space.
        raw_samples: The number of raw candidate points to sample when
            generating additional starting points.
        batch_initial_conditions: Optional `n x 1 x d`-dim tensor of
            user-provided starting points.
        X_avoid: A tensor of points that must not be used as starting points.
        inequality_constraints: A list of `(indices, coefficients, rhs)`
            tuples encoding linear inequality constraints.
        min_points: The minimum number of starting points required.
        acq_function: The acquisition function used to rank generated points.
        max_batch_size: The maximum number of points to evaluate per batch
            when computing acquisition values.
        max_tries: Maximum number of attempts when generating points.

    Returns:
        An `n' x 1 x d`-dim tensor of starting points.
    """
    required_min_points = min_points
    provided_X0 = None
    generated_X0 = None

    if batch_initial_conditions is not None:
        # Drop provided points that collide with `X_avoid` or violate the
        # inequality constraints.
        provided_X0 = _filter_invalid(
            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
        )
        provided_X0 = _filter_infeasible(
            X=provided_X0, inequality_constraints=inequality_constraints
        ).unsqueeze(1)
        # NOTE(review): this subtracts the unfiltered count, so points
        # removed by the filters above are not replaced by generated ones —
        # confirm that is intended.
        required_min_points -= batch_initial_conditions.shape[0]

    if required_min_points > 0:
        generated_X0 = _gen_batch_initial_conditions_local_search(
            discrete_choices=discrete_choices,
            raw_samples=raw_samples,
            X_avoid=X_avoid,
            inequality_constraints=inequality_constraints,
            min_points=min_points,
            max_tries=max_tries,
        )
        # Rank the generated candidates by acquisition value and keep only
        # the best `min_points` of them.
        with torch.no_grad():
            acqvals_init = _split_batch_eval_acqf(
                acq_function=acq_function,
                X=generated_X0.unsqueeze(1),
                max_batch_size=max_batch_size,
            ).unsqueeze(-1)
        generated_X0 = generated_X0[
            acqvals_init.topk(k=min_points, largest=True, dim=0).indices
        ]

    # `_combine_initial_conditions` accepts None for either argument, so the
    # original `x if x is not None else None` wrappers were redundant.
    return _combine_initial_conditions(
        provided_initial_conditions=provided_X0,
        generated_initial_conditions=generated_X0,
    )
1320
+
1321
+
1200
1322
def optimize_acqf_discrete_local_search (
1201
1323
acq_function : AcquisitionFunction ,
1202
1324
discrete_choices : list [Tensor ],
@@ -1207,6 +1329,7 @@ def optimize_acqf_discrete_local_search(
1207
1329
X_avoid : Tensor | None = None ,
1208
1330
batch_initial_conditions : Tensor | None = None ,
1209
1331
max_batch_size : int = 2048 ,
1332
+ max_tries : int = 100 ,
1210
1333
unique : bool = True ,
1211
1334
) -> tuple [Tensor , Tensor ]:
1212
1335
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1361,8 @@ def optimize_acqf_discrete_local_search(
1238
1361
max_batch_size: The maximum number of choices to evaluate in batch.
1239
1362
A large limit can cause excessive memory usage if the model has
1240
1363
a large training set.
1364
+ max_tries: Maximum number of iterations to try when generating initial
1365
+ conditions.
1241
1366
unique: If True return unique choices, o/w choices may be repeated
1242
1367
(only relevant if `q > 1`).
1243
1368
@@ -1247,6 +1372,16 @@ def optimize_acqf_discrete_local_search(
1247
1372
- a `q x d`-dim tensor of generated candidates.
1248
1373
- an associated acquisition value.
1249
1374
"""
1375
+ if batch_initial_conditions is not None :
1376
+ if not (
1377
+ len (batch_initial_conditions .shape ) == 3
1378
+ and batch_initial_conditions .shape [- 2 ] == 1
1379
+ ):
1380
+ raise ValueError (
1381
+ "batch_initial_conditions must have shape `n x 1 x d` if "
1382
+ f"given (recieved { batch_initial_conditions } )."
1383
+ )
1384
+
1250
1385
candidate_list = []
1251
1386
base_X_pending = acq_function .X_pending if q > 1 else None
1252
1387
base_X_avoid = X_avoid
@@ -1259,27 +1394,18 @@ def optimize_acqf_discrete_local_search(
1259
1394
inequality_constraints = inequality_constraints or []
1260
1395
for i in range (q ):
1261
1396
# generate some starting points
1262
- if i == 0 and batch_initial_conditions is not None :
1263
- X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1264
- X0 = _filter_infeasible (
1265
- X = X0 , inequality_constraints = inequality_constraints
1266
- ).unsqueeze (1 )
1267
- else :
1268
- X_init = _gen_batch_initial_conditions_local_search (
1269
- discrete_choices = discrete_choices ,
1270
- raw_samples = raw_samples ,
1271
- X_avoid = X_avoid ,
1272
- inequality_constraints = inequality_constraints ,
1273
- min_points = num_restarts ,
1274
- )
1275
- # pick the best starting points
1276
- with torch .no_grad ():
1277
- acqvals_init = _split_batch_eval_acqf (
1278
- acq_function = acq_function ,
1279
- X = X_init .unsqueeze (1 ),
1280
- max_batch_size = max_batch_size ,
1281
- ).unsqueeze (- 1 )
1282
- X0 = X_init [acqvals_init .topk (k = num_restarts , largest = True , dim = 0 ).indices ]
1397
+ X0 = _gen_starting_points_local_search (
1398
+ discrete_choices = discrete_choices ,
1399
+ raw_samples = raw_samples ,
1400
+ batch_initial_conditions = batch_initial_conditions ,
1401
+ X_avoid = X_avoid ,
1402
+ inequality_constraints = inequality_constraints ,
1403
+ min_points = num_restarts ,
1404
+ acq_function = acq_function ,
1405
+ max_batch_size = max_batch_size ,
1406
+ max_tries = max_tries ,
1407
+ )
1408
+ batch_initial_conditions = None
1283
1409
1284
1410
# optimize from the best starting points
1285
1411
best_xs = torch .zeros (len (X0 ), dim , device = device , dtype = dtype )
0 commit comments