@@ -114,6 +114,19 @@ def __post_init__(self) -> None:
114
114
f"shape is { batch_initial_conditions_shape } ."
115
115
)
116
116
117
+ if (
118
+ self .raw_samples is not None
119
+ and (self .raw_samples - batch_initial_conditions_shape [- 2 ]) > 0
120
+ and len (batch_initial_conditions_shape ) == 3
121
+ and self .num_restarts is not None
122
+ and self .num_restarts != batch_initial_conditions_shape [0 ]
123
+ ):
124
+ raise ValueError (
125
+ "If using `batch_initial_conditions` together with `raw_samples`, "
126
+ "the first repeat dimension of `batch_initial_conditions` must "
127
+ "match `num_restarts`."
128
+ )
129
+
117
130
elif self .ic_generator is None :
118
131
if self .nonlinear_inequality_constraints is not None :
119
132
raise RuntimeError (
@@ -253,22 +266,44 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
253
266
254
267
initial_conditions_provided = opt_inputs .batch_initial_conditions is not None
255
268
269
+ required_raw_samples = opt_inputs .raw_samples
256
270
if initial_conditions_provided :
257
- batch_initial_conditions = opt_inputs .batch_initial_conditions
271
+ provided_initial_conditions = opt_inputs .batch_initial_conditions
272
+ if opt_inputs .raw_samples is not None :
273
+ required_raw_samples -= provided_initial_conditions .shape [- 2 ]
258
274
else :
275
+ provided_initial_conditions = None
276
+
277
+ if required_raw_samples is not None and required_raw_samples > 0 :
259
278
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
260
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
279
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
261
280
acq_function = opt_inputs .acq_function ,
262
281
bounds = opt_inputs .bounds ,
263
282
q = opt_inputs .q ,
264
283
num_restarts = opt_inputs .num_restarts ,
265
- raw_samples = opt_inputs . raw_samples ,
284
+ raw_samples = required_raw_samples ,
266
285
fixed_features = opt_inputs .fixed_features ,
267
286
options = options ,
268
287
inequality_constraints = opt_inputs .inequality_constraints ,
269
288
equality_constraints = opt_inputs .equality_constraints ,
270
289
** opt_inputs .ic_gen_kwargs ,
271
290
)
291
+ else :
292
+ generated_initial_conditions = None
293
+
294
+ if provided_initial_conditions is not None and generated_initial_conditions is not None :
295
+ provided_initial_conditions = provided_initial_conditions .repeat (
296
+ opt_inputs .num_restarts , * ([1 ] * (provided_initial_conditions .dim ()- 1 ))
297
+ )
298
+ batch_initial_conditions = torch .cat (
299
+ [provided_initial_conditions , generated_initial_conditions ], dim = - 2
300
+ ) # should this be shuffled?
301
+ elif provided_initial_conditions is not None :
302
+ batch_initial_conditions = provided_initial_conditions
303
+ elif generated_initial_conditions is not None :
304
+ batch_initial_conditions = generated_initial_conditions
305
+ else :
306
+ raise ValueError ("Either `batch_initial_conditions` or `raw_samples` must be set." )
272
307
273
308
batch_limit : int = options .get (
274
309
"batch_limit" ,
@@ -339,31 +374,39 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
339
374
first_warn_msg = (
340
375
"Optimization failed in `gen_candidates_scipy` with the following "
341
376
f"warning(s):\n { [w .message for w in ws ]} \n Because you specified "
342
- "`batch_initial_conditions`, optimization will not be retried with "
343
- "new initial conditions and will proceed with the current solution."
344
- " Suggested remediation: Try again with different "
345
- "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
346
- if initial_conditions_provided
377
+ "`batch_initial_conditions`>`raw_samples`, optimization will not "
378
+ "be retried with new initial conditions and will proceed with the "
379
+ "current solution. Suggested remediation: Try again with different "
380
+ "`batch_initial_conditions`, don't provide `batch_initial_conditions, "
381
+ "or increase `raw_samples`.`"
382
+ if required_raw_samples is not None and required_raw_samples > 0
347
383
else "Optimization failed in `gen_candidates_scipy` with the following "
348
384
f"warning(s):\n { [w .message for w in ws ]} \n Trying again with a new "
349
385
"set of initial conditions."
350
386
)
351
387
warnings .warn (first_warn_msg , RuntimeWarning , stacklevel = 2 )
352
388
353
- if not initial_conditions_provided :
354
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
389
+ if required_raw_samples is not None and required_raw_samples > 0 :
390
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
355
391
acq_function = opt_inputs .acq_function ,
356
392
bounds = opt_inputs .bounds ,
357
393
q = opt_inputs .q ,
358
394
num_restarts = opt_inputs .num_restarts ,
359
- raw_samples = opt_inputs . raw_samples ,
395
+ raw_samples = required_raw_samples ,
360
396
fixed_features = opt_inputs .fixed_features ,
361
397
options = options ,
362
398
inequality_constraints = opt_inputs .inequality_constraints ,
363
399
equality_constraints = opt_inputs .equality_constraints ,
364
400
** opt_inputs .ic_gen_kwargs ,
365
401
)
366
402
403
+ if provided_initial_conditions is not None :
404
+ batch_initial_conditions = torch .cat (
405
+ [provided_initial_conditions , generated_initial_conditions ], dim = - 2
406
+ ) # should this be shuffled?
407
+ else :
408
+ batch_initial_conditions = generated_initial_conditions
409
+
367
410
batch_candidates , batch_acq_values , ws = _optimize_batch_candidates ()
368
411
369
412
optimization_warning_raised = any (
@@ -1199,11 +1242,46 @@ def optimize_acqf_discrete_local_search(
1199
1242
inequality_constraints = inequality_constraints or []
1200
1243
for i in range (q ):
1201
1244
# generate some starting points
1202
- if i == 0 and batch_initial_conditions is not None :
1203
- X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1204
- X0 = _filter_infeasible (
1205
- X = X0 , inequality_constraints = inequality_constraints
1206
- ).unsqueeze (1 )
1245
+ if i == 0 :
1246
+
1247
+ if batch_initial_conditions is not None :
1248
+ provided_X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1249
+ provided_X0 = _filter_infeasible (
1250
+ X = provided_X0 , inequality_constraints = inequality_constraints
1251
+ ).unsqueeze (1 )
1252
+ if raw_samples is not None :
1253
+ required_raw_samples = raw_samples - batch_initial_conditions .shape [- 2 ]
1254
+ else :
1255
+ required_raw_samples = raw_samples
1256
+ provided_X0 = None
1257
+
1258
+ if required_raw_samples > 0 :
1259
+ X_init = _gen_batch_initial_conditions_local_search (
1260
+ discrete_choices = discrete_choices ,
1261
+ raw_samples = required_raw_samples ,
1262
+ X_avoid = X_avoid ,
1263
+ inequality_constraints = inequality_constraints ,
1264
+ min_points = num_restarts ,
1265
+ )
1266
+ # pick the best starting points
1267
+ with torch .no_grad ():
1268
+ acqvals_init = _split_batch_eval_acqf (
1269
+ acq_function = acq_function ,
1270
+ X = X_init .unsqueeze (1 ),
1271
+ max_batch_size = max_batch_size ,
1272
+ ).unsqueeze (- 1 )
1273
+ generated_X0 = X_init [acqvals_init .topk (k = num_restarts , largest = True , dim = 0 ).indices ]
1274
+
1275
+ if provided_X0 is not None and generated_X0 is not None :
1276
+ provided_X0 = provided_X0 .repeat (num_restarts , * ([1 ] * (provided_X0 .ndim - 1 )))
1277
+ X0 = torch .cat ([provided_X0 , generated_X0 ], dim = - 2 )
1278
+ elif provided_X0 is not None :
1279
+ X0 = provided_X0
1280
+ elif generated_X0 is not None :
1281
+ X0 = generated_X0
1282
+ else :
1283
+ raise ValueError ("Either `batch_initial_conditions` or `raw_samples` must be set." )
1284
+
1207
1285
else :
1208
1286
X_init = _gen_batch_initial_conditions_local_search (
1209
1287
discrete_choices = discrete_choices ,
0 commit comments