diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 6f93c668a4..c11eaf1a7e 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -18,12 +18,21 @@ International Conference on Artificial Intelligence and Statistics. PMLR, 2021, http://proceedings.mlr.press/v130/eriksson21a.html +.. [song2024vizier] + Song, Xingyou and others. The vizier gaussian process bandit algorithm + arXiv preprint arXiv:2408.11527. + https://arxiv.org/abs/2408.11527 + """ from __future__ import annotations from abc import ABC, abstractmethod from collections import OrderedDict +from itertools import product + +import numpy as np +import scipy.stats as stats import torch from botorch.models.transforms.utils import ( @@ -276,20 +285,22 @@ def forward( "the `batch_shape` argument to `Standardize`, but got " f"Y.shape[:-2]={Y.shape[:-2]}." ) + if Y.size(-1) != self._m: raise RuntimeError( f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected " f"{self._m}." ) + if Y.shape[-2] < 1: raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") - elif Y.shape[-2] == 1: stdvs = torch.ones( (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device ) else: stdvs = Y.std(dim=-2, keepdim=True) + stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)) means = Y.mean(dim=-2, keepdim=True) if self._outputs is not None: @@ -823,3 +834,643 @@ def untransform_posterior(self, posterior: Posterior) -> TransformedPosterior: posterior=posterior, sample_transform=lambda x: x.sign() * x.abs().expm1(), ) + + +def _nanmax( + tensor: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor | tuple[Tensor, Tensor]: + """Compute the maximum of a tensor, ignoring NaNs.""" + min_value = torch.finfo(tensor.dtype).min + if dim is None: + return tensor.nan_to_num(min_value).max() + return tensor.nan_to_num(min_value).max(dim=dim, keepdim=keepdim) + + +def _nanmin( + tensor: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor | tuple[Tensor, Tensor]: + """Compute the minimum of a tensor, ignoring NaNs.""" + max_value = torch.finfo(tensor.dtype).max + if dim is None: + return tensor.nan_to_num(max_value).min() + return tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim) + + +def _check_batched_output(Y: Tensor, batch_shape: Tensor, m: int) -> None: + """Utility for common output transform checks.""" + if Y.shape[:-2] != batch_shape: + raise RuntimeError( + f"Expected Y.shape[:-2] to be {batch_shape}, matching " + "the `batch_shape` argument to the `OutcomeTransform`, but got " + f"Y.shape[:-2]={Y.shape[:-2]}." + ) + + if Y.shape[-1] != m: + raise ValueError(f"Expected Y.shape[-1] to be {m}, but got {Y.shape[-1]}.") + + if Y.shape[-2] < 1: + raise ValueError(f"Can't transform with no observations. {Y.shape=}.") + + +class InfeasibleTransform(OutcomeTransform): + """Transforms infeasible (NaN) values to feasible values. + + Inspired by output-space transformations in Vizier [song2024vizier]_. + """ + + def __init__(self, m: int, batch_shape: torch.Size = torch.Size()) -> None: + """Transforms infeasible (NaN) values to feasible values. + + Args: + batch_shape: The batch shape of the outcomes. 
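+            m: The output dimension.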
+ """ + super().__init__() + self._m = m + self._batch_shape = batch_shape + self.register_buffer("_shift", torch.zeros([*batch_shape, m])) + self.register_buffer("warped_bad_value", torch.zeros([*batch_shape, m])) + self.register_buffer("_is_trained", torch.tensor(False)) + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes by handling NaN values. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + _check_batched_output(Y, self._batch_shape) + + if self.training: + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + labels_range = _nanmax(Y, dim=-2).values - _nanmin(Y, dim=-2).values + warped_bad_value = _nanmin(Y, dim=-2).values - (0.5 * labels_range + 1) + num_feasible = Y.shape[-2] - torch.isnan(Y).sum(dim=-2) + + # Estimate the relative frequency of feasible points + p_feasible = (0.5 + num_feasible) / (1 + Y.shape[-2]) + + self.warped_bad_value = warped_bad_value + self._shift = -torch.nanmean(Y, dim=-2) * p_feasible - warped_bad_value * ( + 1 - p_feasible + ) + + self._is_trained = torch.tensor(True) + + # Expand warped_bad_value to match Y's shape + expanded_bad_value = self.warped_bad_value.unsqueeze(-2).expand_as(Y) + expanded_shift = self._shift.unsqueeze(-2).expand_as(Y) + Y = torch.where(torch.isnan(Y), expanded_bad_value, Y + expanded_shift) + + if Yvar is not None: + Yvar = torch.where(torch.isnan(Y), torch.tensor(0.0), Yvar) + + return Y, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). + """ + if not self._is_trained: + raise RuntimeError( + "forward() needs to be called before untransform() is called." + ) + + # Expand shift to match Y's shape + expanded_shift = self._shift.unsqueeze(-2).expand_as(Y) + Y -= expanded_shift + return Y, Yvar + + +class LogWarperTransform(OutcomeTransform): + r"""Warps an array of labels to highlight the difference between good values. + + NOTE that this warping is performed on finite values of the array and NaNs are + untouched. + + Inspired by output-space transformations in Vizier [song2024vizier]_. + + The log warping process consists of two transformations: + + 1. Normalization: + + .. math:: + + \hat{y} = \frac{y_{\max} - y}{y_{\max} - y_{\min}} + + 2. Log Warping: + + .. math:: + + \hat{y}_{\text{warped}} = 0.5 - \frac{\log(1 + (s - 1) \cdot \hat{y})}{\log(s)} + + Where: + - :math:`y` is the input value + - :math:`y_{\min}` is the minimum value in the dataset + - :math:`y_{\max}` is the maximum value in the dataset + - :math:`s` is a free parameter (default 1.5) + + """ + + def __init__( + self, m: int, batch_shape: torch.Size = torch.Size(), offset: float = 1.5 + ) -> None: + """Initialize transform. + + Args: + m: The output dimension. 
+ batch_shape: The batch_shape of the training targets. + offset: Offset parameter for the log transformation. Larger values + of the offset parameter will lead to greater spreading of good + values. Must be > 1. + """ + super().__init__() + if offset <= 0: + raise ValueError("offset must be positive") + self._m = m + self._batch_shape = batch_shape + self.register_buffer("offset", torch.tensor(offset)) + self.register_buffer("_labels_min", torch.zeros([*batch_shape, m])) + self.register_buffer("_labels_max", torch.zeros([*batch_shape, m])) + self.register_buffer("_is_trained", torch.tensor(False)) + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + _check_batched_output(Y, self._batch_shape, self._m) + + if Yvar is not None: + raise NotImplementedError( + "LogWarperTransform does not support transforming observation noise" + ) + + if self.training: + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + self._labels_min = _nanmin(Y, dim=-2).values + self._labels_max = _nanmax(Y, dim=-2).values + self._is_trained = torch.tensor(True) + + expanded_labels_min = self._labels_min.unsqueeze(-2).expand_as(Y) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand_as(Y) + + # Calculate normalized difference + norm_diff = (expanded_labels_max - Y) / ( + expanded_labels_max - expanded_labels_min + ) + Y_transformed = 0.5 - ( + torch.log1p(norm_diff * (self.offset - 1)) / torch.log(self.offset) + ) + + return Y_transformed, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). + """ + if not self._is_trained: + raise RuntimeError("forward() needs to be called before untransform()") + + if Yvar is not None: + raise NotImplementedError( + "LogWarperTransform does not support untransforming observation noise" + ) + + expanded_labels_min = self._labels_min.unsqueeze(-2).expand_as(Y) + expanded_labels_max = self._labels_max.unsqueeze(-2).expand_as(Y) + + Y_untransformed = expanded_labels_max - ( + (torch.exp(torch.log(self.offset) * (0.5 - Y)) - 1) + * (expanded_labels_max - expanded_labels_min) + / (self.offset - 1) + ) + + return Y_untransformed, Yvar + + +class HalfRankTransform(OutcomeTransform): + """Warps half of the outcomes to fit into a Gaussian distribution. + + This transform warps values below the median to follow a Gaussian distribution while + leaving values above the median unchanged. NaN values are preserved. + + Inspired by output-space transformations in Vizier [song2024vizier]_. + """ + + def __init__(self, m: int, batch_shape: torch.Size = torch.Size()) -> None: + """Initialize transform. + + Args: + m: The output dimension. 
+ batch_shape: The batch_shape of the training targets. + """ + super().__init__() + self._m = m + self._batch_shape = batch_shape + self.register_buffer("_original_label_medians", torch.zeros([*batch_shape, m])) + self.register_buffer("_is_trained", torch.tensor(False)) + + # TODO these are ragged tensors, we should use a better data structure such + # that they are saved to the state_dict + self._unique_labels = {} + self._warped_labels = {} + + def _get_std_above_median(self, unique_y: Tensor, y_median: Tensor) -> Tensor: + # Estimate std of good half + good_half = unique_y[unique_y >= y_median] + std = torch.sqrt(((good_half - y_median) ** 2).mean()) + + if std == 0: + std = torch.sqrt(((unique_y - y_median) ** 2).mean()) + + if torch.isnan(std): + std = torch.abs(unique_y - y_median).mean() + + return std + + def forward( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of training targets. + Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + if Yvar is not None: + raise NotImplementedError( + "HalfRankTransform does not support transforming observation noise" + ) + + _check_batched_output(Y, self._batch_shape, self._m) + Y_transformed = Y.clone() + + if self.training: + if torch.isnan(Y).all(dim=-2).any(): + raise RuntimeError("For at least one batch, all outcomes are NaN") + + # Compute median for each batch + Y_medians = torch.nanmedian(Y, dim=-2).values + + for dim in range(Y.shape[-1]): + batch_indices = ( + product(*([m for m in range(n)] for n in self._batch_shape)) + if self._batch_shape is not None and len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + (...,), + ] + ) + for batch_idx in batch_indices: + y_median = Y_medians[*batch_idx, dim] + y = Y_transformed[*batch_idx, :, dim] + + # Get finite values and their ranks for each batch + is_finite_mask = ~torch.isnan(y) + + # TODO: this is annoying but torch.unique doesn't support + # returning indices + np_unique_y, np_unique_indices = np.unique( + y[is_finite_mask].numpy(force=True), return_index=True + ) + ranks = stats.rankdata(y.numpy(force=True), method="dense") + + unique_y = torch.from_numpy(np_unique_y).to(y.device) + unique_indices = torch.from_numpy(np_unique_indices).to(y.device) + ranks = torch.from_numpy(ranks).to(y.device) + + # Calculate rank quantiles + dedup_median_index = torch.searchsorted(unique_y, y_median) + denominator = 2 * dedup_median_index + ( + unique_y[dedup_median_index] == y_median + ) + rank_quantile = (ranks - 0.5) / denominator + + y_above_median_std = self._get_std_above_median(unique_y, y_median) + + # Apply transformation + rank_ppf = ( + torch.erfinv(2 * rank_quantile - 1) + * y_above_median_std + * torch.sqrt(torch.tensor(2.0)) + ) + Y_transformed[*batch_idx, :, dim] = torch.where( + y < y_median, + rank_ppf + y_median, + Y_transformed[*batch_idx, :, dim], + ) + + # save intermediate values for untransform + self._original_label_medians[*batch_idx, dim] = y_median + self._unique_labels[(*batch_idx, dim)] = unique_y + self._warped_labels[(*batch_idx, dim)] = Y_transformed[ + *batch_idx, :, dim + ][is_finite_mask][unique_indices] + + self._is_trained = torch.tensor(True) + return Y_transformed, Yvar + + for dim in 
range(Y.shape[-1]): + batch_indices = ( + product(*([m for m in range(n)] for n in self._batch_shape)) + if self._batch_shape is not None and len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + (...,), + ] + ) + for batch_idx in batch_indices: + y_median = self._original_label_medians[*batch_idx, dim] + y = Y[*batch_idx, :, dim] + warped_labels: torch.Tensor = self._warped_labels[(*batch_idx, dim)] + unique_labels: torch.Tensor = self._unique_labels[(*batch_idx, dim)] + + # Process values below median + below_median = y < self._original_label_medians[*batch_idx, dim] + if below_median.any(): + # Find nearest original values and perform lookup + original_idx = torch.searchsorted(unique_labels, y[below_median]) + + # Create indices for neighboring values + left_idx = torch.clamp(original_idx - 1, min=0) + right_idx = torch.clamp(original_idx + 1, max=len(unique_labels)) + + # Gather neighboring values + candidates = torch.stack( + [ + unique_labels[left_idx], + unique_labels[original_idx], + unique_labels[right_idx], + ], + dim=-1, + ) + + # Find nearest original values and perform lookup + best_idx = torch.argmin( + torch.abs(candidates - y[below_median].unsqueeze(-1)), dim=-1 + ) + + lookup_mask = torch.isclose( + candidates[torch.arange(len(best_idx)), best_idx], + y[below_median], + ) + full_lookup_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + lookup_indices = below_median_indices[lookup_mask] + full_lookup_mask[lookup_indices] = True + full_lookup_values = torch.zeros_like(Y[*batch_idx, :, dim]) + full_lookup_values[full_lookup_mask] = warped_labels[ + original_idx[lookup_mask] + ] + Y_transformed[*batch_idx, :, dim] = torch.where( + full_lookup_mask, + full_lookup_values, + Y_transformed[*batch_idx, :, dim], + ) + + # if the value is below the original minimum, we need to + # extrapolate outside the range + extrapolate_mask = y < unique_labels[0] + extrapolated_values = warped_labels[0] - ( + y[extrapolate_mask] - unique_labels[0] + ).abs() / (unique_labels.max() - unique_labels.min()) * ( + warped_labels.max() - warped_labels.min() + ) + full_extrapolated_values = torch.zeros_like(Y[*batch_idx, :, dim]) + full_extrapolated_values[extrapolate_mask] = extrapolated_values + Y_transformed[*batch_idx, :, dim] = torch.where( + extrapolate_mask, + full_extrapolated_values, + Y_transformed[*batch_idx, :, dim], + ) + + # otherwise, interpolate + neither_extrapolate_nor_lookup = ~( + (y[below_median] < unique_labels[0]) | lookup_mask + ) + y_neither_extrapolate_nor_lookup = y[below_median][ + neither_extrapolate_nor_lookup + ] + warped_idx_neither_extrapolate_nor_lookup = original_idx[ + neither_extrapolate_nor_lookup + ] + + lower_idx = (warped_idx_neither_extrapolate_nor_lookup - 1,) + upper_idx = (warped_idx_neither_extrapolate_nor_lookup,) + + original_gap = unique_labels[upper_idx] - unique_labels[lower_idx] + warped_gap = warped_labels[upper_idx] - warped_labels[lower_idx] + + full_interpolated_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + interpolated_indices = below_median_indices[ + neither_extrapolate_nor_lookup + ] + full_interpolated_mask[interpolated_indices] = True + + full_interpolated_values = torch.zeros_like( + Y_transformed[*batch_idx, :, dim] + ) + full_interpolated_values[full_interpolated_mask] = torch.where( + original_gap > 0, + warped_labels[lower_idx] + + (y_neither_extrapolate_nor_lookup - unique_labels[lower_idx]) + / original_gap + 
* warped_gap, + warped_labels[lower_idx], + ) + + Y_transformed[*batch_idx, :, dim] = torch.where( + full_interpolated_mask, + full_interpolated_values, + Y_transformed[*batch_idx, :, dim], + ) + + return Y_transformed, Yvar + + def untransform( + self, Y: Tensor, Yvar: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: + """Un-transform the outcomes. + + Args: + Y: A `batch_shape x n x m`-dim tensor of transformed targets. + Yvar: A `batch_shape x n x m`-dim tensor of transformed observation + noises associated with the targets (if applicable). + + Returns: + A two-tuple with the un-transformed outcomes: + - The un-transformed outcome observations. + - The un-transformed observation noise (if applicable). + """ + if not self._is_trained: + raise RuntimeError("forward() needs to be called before untransform()") + + if Yvar is not None: + raise NotImplementedError( + "HalfRankTransform does not support untransforming observation noise" + ) + + Y_utf = Y.clone() + + for dim in range(Y.shape[-1]): + batch_indices = ( + product(*(range(n) for n in self._batch_shape)) + if self._batch_shape is not None and len(self._batch_shape) > 0 + else [ # this allows it to work with no batch dim + (...,), + ] + ) + for batch_idx in batch_indices: + y = Y_utf[*batch_idx, :, dim].clone() + unique_labels = self._unique_labels[(*batch_idx, dim)] + warped_labels = self._warped_labels[(*batch_idx, dim)] + + # Process values below median + below_median = y < self._original_label_medians[*batch_idx, dim] + if below_median.any(): + # Find nearest warped values and perform lookup + warped_idx = torch.searchsorted(warped_labels, y[below_median]) + + # Create indices for neighboring values + left_idx = torch.clamp(warped_idx - 1, min=0) + right_idx = torch.clamp(warped_idx + 1, max=len(warped_labels)) + + # Gather neighboring values + candidates = torch.stack( + [ + warped_labels[left_idx], + warped_labels[warped_idx], + warped_labels[right_idx], + ], + dim=-1, + ) + + best_idx = torch.argmin( + torch.abs(candidates - y[below_median].unsqueeze(-1)), dim=-1 + ) + lookup_mask = torch.isclose( + candidates[torch.arange(len(best_idx)), best_idx], + y[below_median], + ) + full_lookup_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + lookup_indices = below_median_indices[lookup_mask] + full_lookup_mask[lookup_indices] = True + full_lookup_values = torch.zeros_like(Y_utf[*batch_idx, :, dim]) + full_lookup_values[full_lookup_mask] = unique_labels[ + warped_idx[lookup_mask] + ] + Y_utf[*batch_idx, :, dim] = torch.where( + full_lookup_mask, full_lookup_values, Y_utf[*batch_idx, :, dim] + ) + + # if the value is below the warped minimum, we need to + # extrapolate outside the range + extrapolate_mask = y < warped_labels[0] + extrapolated_values = unique_labels[0] - ( + y[extrapolate_mask] - warped_labels[0] + ).abs() / (warped_labels.max() - warped_labels.min()) * ( + unique_labels.max() - unique_labels.min() + ) + full_extrapolated_values = torch.zeros_like( + Y_utf[*batch_idx, :, dim] + ) + full_extrapolated_values[extrapolate_mask] = extrapolated_values + Y_utf[*batch_idx, :, dim] = torch.where( + extrapolate_mask, + full_extrapolated_values, + Y_utf[*batch_idx, :, dim], + ) + + # otherwise, interpolate + neither_extrapolate_nor_lookup = ~( + (y[below_median] < warped_labels[0]) | lookup_mask + ) + y_neither_extrapolate_nor_lookup = y[below_median][ + neither_extrapolate_nor_lookup + ] + warped_idx_neither_extrapolate_nor_lookup = warped_idx[ + 
neither_extrapolate_nor_lookup + ] + + lower_idx = (warped_idx_neither_extrapolate_nor_lookup - 1,) + upper_idx = (warped_idx_neither_extrapolate_nor_lookup,) + + original_gap = unique_labels[upper_idx] - unique_labels[lower_idx] + warped_gap = warped_labels[upper_idx] - warped_labels[lower_idx] + + full_interpolated_mask = torch.full_like(below_median, False) + below_median_indices = torch.where(below_median)[0] + interpolated_indices = below_median_indices[ + neither_extrapolate_nor_lookup + ] + full_interpolated_mask[interpolated_indices] = True + + full_interpolated_values = torch.zeros_like( + Y_utf[*batch_idx, :, dim] + ) + full_interpolated_values[full_interpolated_mask] = torch.where( + warped_gap > 0, + unique_labels[lower_idx] + + (y_neither_extrapolate_nor_lookup - warped_labels[lower_idx]) + / warped_gap + * original_gap, + unique_labels[lower_idx], + ) + + Y_utf[*batch_idx, :, dim] = torch.where( + full_interpolated_mask, + full_interpolated_values, + Y_utf[*batch_idx, :, dim], + ) + + return Y_utf, Yvar diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index f4abe3fd87..5156bba684 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -22,7 +22,11 @@ LinearHomotopySchedule, LogLinearHomotopySchedule, ) -from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg +from botorch.optim.initializers import ( + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topn, +) from botorch.optim.optimize import ( gen_batch_initial_conditions, optimize_acqf, @@ -43,6 +47,7 @@ "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", + "initialize_q_batch_topn", "OptimizationResult", "OptimizationStatus", "optimize_acqf", diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index af0f918f4a..fbf975cedc 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -328,14 +328,24 @@ def gen_batch_initial_conditions( init_kwargs = {} device = bounds.device bounds_cpu = bounds.cpu() - if "eta" in options: - init_kwargs["eta"] = options.get("eta") - if options.get("nonnegative") or is_nonnegative(acq_function): + + if options.get("topn"): + init_func = initialize_q_batch_topn + init_func_opts = ["sorted", "largest"] + elif options.get("nonnegative") or is_nonnegative(acq_function): init_func = initialize_q_batch_nonneg - if "alpha" in options: - init_kwargs["alpha"] = options.get("alpha") + init_func_opts = ["alpha", "eta"] else: init_func = initialize_q_batch + init_func_opts = ["eta"] + + for opt in init_func_opts: + # default value of "largest" to "acq_function.maximize" if it exists + if opt == "largest" and hasattr(acq_function, "maximize"): + init_kwargs[opt] = acq_function.maximize + + if opt in options: + init_kwargs[opt] = options.get(opt) q = 1 if q is None else q # the dimension the samples are drawn from @@ -363,7 +373,9 @@ def gen_batch_initial_conditions( X_rnd_nlzd = torch.rand( n, q, bounds_cpu.shape[-1], dtype=bounds.dtype ) - X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) else: X_rnd = sample_q_batches_from_polytope( n=n, @@ -375,7 +387,8 @@ def gen_batch_initial_conditions( equality_constraints=equality_constraints, inequality_constraints=inequality_constraints, ) - # sample points around best + + # sample additional points around best if sample_around_best: X_best_rnd = sample_points_around_best( acq_function=acq_function, @@ -395,6 
+408,8 @@ def gen_batch_initial_conditions( ) # Keep X on CPU for consistency & to limit GPU memory usage. X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu() + + # Append the fixed fantasies to the randomly generated points if fixed_X_fantasies is not None: if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]): raise BotorchTensorDimensionError( @@ -411,6 +426,9 @@ def gen_batch_initial_conditions( ], dim=-2, ) + + # Evaluate the acquisition function on `X_rnd` using `batch_limit` + # sized chunks. with torch.no_grad(): if batch_limit is None: batch_limit = X_rnd.shape[0] @@ -423,16 +441,22 @@ def gen_batch_initial_conditions( ], dim=0, ) + + # Downselect the initial conditions based on the acquisition function values batch_initial_conditions, _ = init_func( X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs ) batch_initial_conditions = batch_initial_conditions.to(device=device) + + # Return the initial conditions if no warnings were raised if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws): return batch_initial_conditions + if factor < max_factor: factor += 1 if seed is not None: seed += 1 # make sure to sample different X_rnd + warnings.warn( "Unable to find non-zero acquisition function values - initial conditions " "are being selected randomly.", @@ -1057,6 +1081,56 @@ def initialize_q_batch_nonneg( return X[idcs], acq_vals[idcs] +def initialize_q_batch_topn( + X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True +) -> tuple[Tensor, Tensor]: + r"""Take the top `n` initial conditions for candidate generation. + + Args: + X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim. + feature space. Typically, these are generated using qMC. + acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this + is the value of the batch acquisition function to be maximized. + n: The number of initial condition to be generated. Must be less than `b`. + + Returns: + - An `n x q x d` tensor of `n` `q`-batch initial conditions. + - An `n` tensor of the corresponding acquisition values. + + Example: + >>> # To get `n=10` starting points of q-batch size `q=3` + >>> # for model with `d=6`: + >>> qUCB = qUpperConfidenceBound(model, beta=0.1) + >>> X_rnd = torch.rand(500, 3, 6) + >>> X_init, acq_init = initialize_q_batch_topn( + ... X=X_rnd, acq_vals=qUCB(X_rnd), n=10 + ... ) + + """ + n_samples = X.shape[0] + if n > n_samples: + raise RuntimeError( + f"n ({n}) cannot be larger than the number of " + f"provided samples ({n_samples})" + ) + elif n == n_samples: + return X, acq_vals + + Ystd = acq_vals.std(dim=0) + if torch.any(Ystd == 0): + warnings.warn( + "All acquisition values for raw samples points are the same for " + "at least one batch. 
Choosing initial conditions at random.", + BadInitialCandidatesWarning, + stacklevel=3, + ) + idcs = torch.randperm(n=n_samples, device=X.device)[:n] + return X[idcs], acq_vals[idcs] + + topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted) + return X[topk_idcs], topk_out + + def sample_points_around_best( acq_function: AcquisitionFunction, n_discrete_points: int, diff --git a/botorch/utils/feasible_volume.py b/botorch/utils/feasible_volume.py index f3b8d2fb76..2608c03c2a 100644 --- a/botorch/utils/feasible_volume.py +++ b/botorch/utils/feasible_volume.py @@ -11,7 +11,7 @@ import botorch.models.model as model import torch from botorch.logging import _get_logger -from botorch.utils.sampling import manual_seed +from botorch.utils.sampling import manual_seed, unnormalize from torch import Tensor @@ -164,9 +164,10 @@ def estimate_feasible_volume( seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() with manual_seed(seed=seed): - box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand( + samples_nlzd = torch.rand( (nsample_feature, bounds.size(1)), dtype=dtype, device=device ) + box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False) features, p_feature = get_feasible_samples( samples=box_samples, inequality_constraints=inequality_constraints diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 52fe54fbb2..f914dea24d 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -98,14 +98,12 @@ def draw_sobol_samples( batch_shape = batch_shape or torch.Size() batch_size = int(torch.prod(torch.tensor(batch_shape))) d = bounds.shape[-1] - lower = bounds[0] - rng = bounds[1] - bounds[0] sobol_engine = SobolEngine(q * d, scramble=True, seed=seed) - samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype) - samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device) + samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype) + samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device) if batch_shape != torch.Size(): samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1) - return lower + rng * samples_raw + return unnormalize(samples_raw, bounds, update_constant_bounds=False) def draw_sobol_normal_samples( diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py index 01f34c0da4..b354821cfb 100644 --- a/botorch/utils/transforms.py +++ b/botorch/utils/transforms.py @@ -66,17 +66,18 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor: return bounds -def normalize(X: Tensor, bounds: Tensor) -> Tensor: +def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor: r"""Min-max normalize X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] - bounds[0, i]`. Returns: A `... 
x d`-dim tensor of normalized data, given by @@ -89,21 +90,27 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X_normalized = normalize(X, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return (X - bounds[0]) / (bounds[1] - bounds[0]) -def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: +def unnormalize( + X: Tensor, bounds: Tensor, update_constant_bounds: bool = True +) -> Tensor: r"""Un-normalizes X w.r.t. the provided bounds. - NOTE: If the upper and lower bounds are identical for a dimension, that dimension - will not be scaled. Such dimensions will only be shifted as - `new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`. - Args: X: `... x d` tensor of data bounds: `2 x d` tensor of lower and upper bounds for each of the X's d columns. + update_constant_bounds: If `True`, update the constant bounds in order to + avoid division by zero issues. When the upper and lower bounds are + identical for a dimension, that dimension will not be scaled. Such + dimensions will only be shifted as + `new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of + the behavior of `normalize` when `update_constant_bounds=True`. Returns: A `... x d`-dim tensor of unnormalized data, given by @@ -116,7 +123,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor: >>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)]) >>> X = unnormalize(X_normalized, bounds) """ - bounds = _update_constant_bounds(bounds=bounds) + bounds = ( + _update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds + ) return X * (bounds[1] - bounds[0]) + bounds[0] diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 49fa23862f..714c85173b 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -9,9 +9,14 @@ import torch from botorch.models.transforms.outcome import ( + _nanmax, + _nanmin, Bilog, ChainedOutcomeTransform, + HalfRankTransform, + InfeasibleTransform, Log, + LogWarperTransform, OutcomeTransform, Power, Standardize, @@ -51,6 +56,70 @@ def forward(self, Y, Yvar): pass +class TestNanMax(BotorchTestCase): + def test_nanmax_basic(self): + tensor = torch.tensor([1.0, float("nan"), 3.0, 2.0]) + result = _nanmax(tensor) + expected = torch.tensor(3.0) + self.assertEqual(result, expected) + + def test_nanmax_with_dim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmax(tensor, dim=1) + expected = torch.tensor([1.0, 3.0]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmax_with_keepdim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmax(tensor, dim=1, keepdim=True) + expected = torch.tensor([[1.0], [3.0]]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmax_all_nan(self): + tensor = torch.tensor([float("nan"), float("nan")]) + result = _nanmax(tensor) + expected = torch.tensor(torch.finfo(tensor.dtype).min) + self.assertEqual(result, expected) + + def test_nanmax_no_nan(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = _nanmax(tensor) + expected = torch.tensor(3.0) + self.assertEqual(result, expected) + + +class TestNanMin(BotorchTestCase): + def test_nanmin_basic(self): + tensor = torch.tensor([1.0, float("nan"), 3.0, 2.0]) + result = _nanmin(tensor) + expected = torch.tensor(1.0) + 
self.assertEqual(result, expected) + + def test_nanmin_with_dim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmin(tensor, dim=1) + expected = torch.tensor([1.0, 2.0]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmin_with_keepdim(self): + tensor = torch.tensor([[1.0, float("nan")], [3.0, 2.0]]) + result = _nanmin(tensor, dim=1, keepdim=True) + expected = torch.tensor([[1.0], [2.0]]) + self.assertTrue(torch.equal(result.values, expected)) + + def test_nanmin_all_nan(self): + tensor = torch.tensor([float("nan"), float("nan")]) + result = _nanmin(tensor) + expected = torch.tensor(torch.finfo(tensor.dtype).max) + self.assertEqual(result, expected) + + def test_nanmin_no_nan(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = _nanmin(tensor) + expected = torch.tensor(1.0) + self.assertEqual(result, expected) + + class TestOutcomeTransforms(BotorchTestCase): def test_abstract_base_outcome_transform(self): with self.assertRaises(TypeError): @@ -817,3 +886,336 @@ def test_bilog(self, seed=0): Y_tf_subset, Yvar_tf_subset = tf_subset(Y[..., [0]], None) self.assertTrue(torch.equal(Y_tf_subset, Y_tf[..., [0]])) self.assertIsNone(Yvar_tf_subset) + + +class TestInfeasibleTransform(BotorchTestCase): + def test_infeasible_transform_init(self): + """Test initialization of InfeasibleTransform.""" + batch_shape = torch.Size([2, 3]) + transform = InfeasibleTransform(batch_shape=batch_shape) + self.assertEqual(transform._batch_shape, batch_shape) + self.assertFalse(transform._is_trained) + self.assertIsNone(transform._shift) + self.assertTrue(torch.isnan(transform.warped_bad_value)) + + def test_infeasible_transform_forward(self): + """Test forward transformation with NaN values.""" + batch_shape = torch.Size([2]) + transform = InfeasibleTransform(batch_shape=batch_shape) + + # Create test data with NaN values + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + Y_orig = Y.clone() + + # Test forward pass in training mode + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transform is now trained + self.assertTrue(transform._is_trained) + self.assertIsNotNone(transform._shift) + self.assertFalse(torch.isnan(transform.warped_bad_value).all()) + + # Check that NaN values are replaced with warped_bad_value + self.assertFalse(torch.isnan(Y_tf).any()) + + # Test forward pass in eval mode + transform.eval() + Y_tf_eval, _ = transform.forward(Y_orig, None) + + # Check that NaN values are replaced consistently + self.assertFalse(torch.isnan(Y_tf_eval).any()) + + def test_infeasible_transform_untransform(self): + """Test untransform functionality.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + # Should raise error if not trained + with self.assertRaises(RuntimeError): + transform.untransform(torch.tensor([1.0, 2.0]), None) + + # Train the transform first + batch_shape = torch.Size([2]) + transform = InfeasibleTransform(batch_shape=batch_shape) + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + + transform.train() + Y_tf, Yvar_tf = transform.forward(Y, Y + 2) + self.assertTrue(torch.allclose(Yvar_tf[:, 1:], Y[:, 1:] + 2)) + + # Test untransform + Y_untf, Yvar_untf = transform.untransform(Y_tf, Yvar_tf) + + # Check that values are properly untransformed + self.assertTrue(torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4)) + self.assertTrue(torch.allclose(Yvar_untf[:, 1:], Yvar_tf[:, 1:], rtol=1e-4)) + # test the unwarped_bad_value + self.assertTrue( + torch.allclose( + 
transform.warped_bad_value[:, 0] - transform._shift[:, 0], + Y_untf[..., 0, 0], + ) + ) + + def test_infeasible_transform_batch_shape_validation(self): + """Test batch shape validation.""" + transform = InfeasibleTransform(batch_shape=torch.Size([2])) + + # Wrong batch shape should raise error + with self.assertRaises(RuntimeError): + transform.forward(torch.randn(3, 4, 2), None) + + def test_infeasible_transform_empty_input(self): + """Test handling of empty input.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + # Empty input should raise error + with self.assertRaises(ValueError): + transform.forward(torch.tensor([]).reshape(0, 1), None) + + def test_infeasible_transform_all_nan(self): + """Test handling of all-NaN input.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + Y = torch.tensor([[float("nan"), float("nan")]]) + transform.train() + with self.assertRaises(RuntimeError): + transform.forward(Y, None) + + def test_infeasible_transform_no_nan(self): + """Test handling of input with no NaN values.""" + transform = InfeasibleTransform(batch_shape=torch.Size([])) + + Y = torch.tensor([[1.0, 2.0, 3.0]]) + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transformation preserves finite values + assert not torch.isnan(Y_tf).any() + Y_untf, _ = transform.untransform(Y_tf, None) + assert torch.allclose(Y_untf, Y, rtol=1e-4) + + +class TestLogWarperTransform(BotorchTestCase): + def test_log_warper_transform_init(self): + """Test initialization of LogWarperTransform.""" + batch_shape = torch.Size([2, 3]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + self.assertEqual(transform._batch_shape, batch_shape) + self.assertEqual(transform.offset.item(), 2.0) + + # Test invalid offset + with self.assertRaisesRegex(ValueError, "offset must be positive"): + LogWarperTransform(offset=0.0) + with self.assertRaisesRegex(ValueError, "offset must be positive"): + LogWarperTransform(offset=-1.0) + + def test_log_warper_transform_forward(self): + """Test forward transformation.""" + batch_shape = torch.Size([2]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + + # Create test data with NaN values + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = float("nan") + Y_orig = Y.clone() + + # Test forward pass in training mode + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Check that transform is now trained + labels_min = transform._labels_min.clone() + labels_max = transform._labels_max.clone() + + self.assertTrue(transform._is_trained) + self.assertTrue(torch.isfinite(labels_min).all()) + self.assertTrue(torch.isfinite(labels_max).all()) + self.assertTrue((torch.isnan(Y_tf) == torch.isnan(Y_orig)).all()) + + # Test forward pass in eval mode + transform.eval() + Y_tf_eval, _ = transform.forward(Y_tf, None) + + # Check that NaN values are replaced consistently + self.assertTrue((torch.isnan(Y_tf_eval) == torch.isnan(Y_tf)).all()) + self.assertTrue(torch.allclose(labels_min, transform._labels_min)) + self.assertTrue(torch.allclose(labels_max, transform._labels_max)) + + def test_log_warper_transform_untransform(self): + """Test untransform functionality.""" + batch_shape = torch.Size([2]) + transform = LogWarperTransform(offset=2.0, batch_shape=batch_shape) + + # Should raise error if not trained + with self.assertRaises(RuntimeError): + transform.untransform(torch.tensor([1.0, 2.0]), None) + + # Train the transform first + Y = torch.randn(*batch_shape, 3, 2) + Y[..., 0, 0] = 
float("nan") + + transform.train() + Y_tf, _ = transform.forward(Y, None) + + # Test untransform + Y_untf, _ = transform.untransform(Y_tf, None) + + # Check that values are properly untransformed + self.assertTrue(torch.allclose(Y_untf[:, 1:], Y[:, 1:], rtol=1e-4)) + + # test the nan values don't change + self.assertTrue(torch.isnan(Y_untf[..., 0, 0]).all()) + + def test_log_warper_transform_batch_shape_validation(self): + """Test batch shape validation.""" + transform = LogWarperTransform(offset=2.0, batch_shape=torch.Size([2])) + + # Wrong batch shape should raise error + with self.assertRaises(RuntimeError): + transform.forward(torch.randn(3, 4, 2), None) + + def test_log_warper_transform_empty_input(self): + """Test handling of empty input.""" + transform = LogWarperTransform(offset=2.0, batch_shape=torch.Size([])) + + # Empty input should raise error + with self.assertRaises(ValueError): + transform.forward(torch.tensor([]).reshape(0, 1), None) + + +class TestHalfRankTransform(BotorchTestCase): + def test_init(self): + # Test initialization + transform = HalfRankTransform() + self.assertEqual(transform._batch_shape, torch.Size([])) + self.assertFalse(transform._is_trained) + self.assertEqual(transform._unique_labels, {}) + self.assertEqual(transform._warped_labels, {}) + + # Test with batch shape + batch_shape = torch.Size([2, 3]) + transform = HalfRankTransform(batch_shape=batch_shape) + self.assertEqual(transform._batch_shape, batch_shape) + + def test_transform_simple_case(self): + # Test with simple 1D tensor + transform = HalfRankTransform() + Y = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Values above median should remain unchanged + self.assertTrue( + torch.allclose(Y_transformed[Y.squeeze() > 3.0], Y[Y.squeeze() > 3.0]) + ) + + # Check if transform is trained + self.assertTrue(transform._is_trained) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def test_transform_with_nans(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # NaN values should remain NaN + self.assertTrue(torch.isnan(Y_transformed[torch.isnan(Y)]).all()) + + # Non-NaN values above median should remain unchanged + valid_mask = ~torch.isnan(Y.squeeze()) + median = torch.nanmedian(Y) + self.assertTrue( + torch.allclose( + Y_transformed[valid_mask & (Y.squeeze() > median)], + Y[valid_mask & (Y.squeeze() > median)], + ) + ) + + def test_transform_batch(self): + batch_shape = torch.Size([2]) + transform = HalfRankTransform(batch_shape=batch_shape) + Y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).reshape(2, 3, 1) + Y_transformed, _ = transform.forward(Y) + + # Shape should be preserved + self.assertEqual(Y_transformed.shape, Y.shape) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def test_transform_multi_output(self): + transform = HalfRankTransform() + Y = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]) + Y_transformed, _ = transform.forward(Y) + + # Each output dimension should be transformed independently + self.assertEqual(Y_transformed.shape, Y.shape) + + # Test untransform + Y_untransformed, _ = transform.untransform(Y_transformed) + self.assertTrue(torch.allclose(Y_untransformed, Y, rtol=1e-4)) + + def 
test_error_cases(self): + transform = HalfRankTransform() + + # Test all NaN case + Y = torch.tensor([[float("nan")], [float("nan")]]) + with self.assertRaisesRegex( + RuntimeError, "For at least one batch, all outcomes are NaN" + ): + transform.forward(Y) + + # Test untransform before training + Y = torch.tensor([[1.0], [2.0]]) + with self.assertRaisesRegex( + RuntimeError, "needs to be called before untransform" + ): + transform.untransform(Y) + + # Test with observation noise + Y = torch.tensor([[1.0], [2.0]]) + Yvar = torch.tensor([[0.1], [0.1]]) + with self.assertRaisesRegex( + NotImplementedError, + "HalfRankTransform does not support transforming observation noise", + ): + transform.forward(Y, Yvar) + + def test_batch_shape_mismatch(self): + batch_shape = torch.Size([2]) + transform = HalfRankTransform(batch_shape=batch_shape) + Y = torch.tensor([[1.0], [2.0], [3.0]]) # Wrong batch shape + with self.assertRaises(RuntimeError): + transform.forward(Y) + + def test_extrapolation(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Test extrapolation below minimum + Y_test = torch.tensor([0.0]).reshape(-1, 1) + Y_test_transformed, _ = transform.forward(Y_test) + Y_test_untransformed, _ = transform.untransform(Y_test_transformed) + + # The untransformed value should be close to but below the minimum + self.assertLess(Y_test_untransformed.item(), Y.min()) + + def test_interpolation(self): + transform = HalfRankTransform() + Y = torch.tensor([1.0, 3.0, 5.0]).reshape(-1, 1) + Y_transformed, _ = transform.forward(Y) + + # Test interpolation between values + Y_test = torch.tensor([2.0]).reshape(-1, 1) + Y_test_transformed, _ = transform.forward(Y_test) + Y_test_untransformed, _ = transform.untransform(Y_test_transformed) + + # The untransformed value should be close to the original + self.assertTrue(torch.allclose(Y_test_untransformed, Y_test, rtol=1e-4)) diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 09be6f2326..902ecfc449 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -30,12 +30,14 @@ from botorch.exceptions.warnings import BotorchWarning from botorch.models import SingleTaskGP from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim import initialize_q_batch, initialize_q_batch_nonneg from botorch.optim.initializers import ( gen_batch_initial_conditions, gen_one_shot_hvkg_initial_conditions, gen_one_shot_kg_initial_conditions, gen_value_function_initial_conditions, + initialize_q_batch, + initialize_q_batch_nonneg, + initialize_q_batch_topn, sample_perturbed_subset_dims, sample_points_around_best, sample_q_batches_from_polytope, @@ -45,7 +47,7 @@ transform_intra_point_constraint, ) from botorch.sampling.normal import IIDNormalSampler -from botorch.utils.sampling import draw_sobol_samples, manual_seed +from botorch.utils.sampling import draw_sobol_samples, manual_seed, unnormalize from botorch.utils.testing import ( _get_max_violation_of_bounds, _get_max_violation_of_constraints, @@ -129,31 +131,68 @@ def test_initialize_q_batch_nonneg(self): self.assertEqual(ics.dtype, X.dtype) def test_initialize_q_batch(self): + for dtype, batch_shape in ( + (torch.float, torch.Size()), + (torch.double, [3, 2]), + (torch.float, (2,)), + (torch.double, torch.Size([2, 3, 4])), + (torch.float, []), + ): + # basic test + X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) + acq_vals = 
torch.rand(5, *batch_shape, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + + def test_initialize_q_batch_topn(self): for dtype in (torch.float, torch.double): - for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []): - # basic test - X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype) - acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype) - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4])) - self.assertEqual(ics_X.device, X.device) - self.assertEqual(ics_X.dtype, X.dtype) - self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape])) - self.assertEqual(ics_acq_vals.device, acq_vals.device) - self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) - # ensure nothing happens if we want all samples - ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5) - self.assertTrue(torch.equal(X, ics_X)) - self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) - # ensure raises correct warning - acq_vals = torch.zeros(5, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as w: - ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) - self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4])) - with self.assertRaises(RuntimeError): - initialize_q_batch(X=X, acq_vals=acq_vals, n=10) + # basic test + X = torch.rand(5, 3, 4, device=self.device, dtype=dtype) + acq_vals = torch.rand(5, device=self.device, dtype=dtype) + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics_X.shape, torch.Size([2, 3, 4])) + self.assertEqual(ics_X.device, X.device) + self.assertEqual(ics_X.dtype, X.dtype) + self.assertEqual(ics_acq_vals.shape, torch.Size([2])) + self.assertEqual(ics_acq_vals.device, acq_vals.device) + self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype) + # ensure nothing happens if we want all samples + ics_X, ics_acq_vals = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=5) + self.assertTrue(torch.equal(X, ics_X)) + self.assertTrue(torch.equal(acq_vals, ics_acq_vals)) + # make sure things work with constant inputs + acq_vals = torch.ones(5, device=self.device, dtype=dtype) + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + 
self.assertEqual(ics.device, X.device) + self.assertEqual(ics.dtype, X.dtype) + # ensure raises correct warning + acq_vals = torch.zeros(5, device=self.device, dtype=dtype) + with warnings.catch_warnings(record=True) as w: + ics, _ = initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=2) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning)) + self.assertEqual(ics.shape, torch.Size([2, 3, 4])) + with self.assertRaises(RuntimeError): + initialize_q_batch_topn(X=X, acq_vals=acq_vals, n=10) def test_initialize_q_batch_largeZ(self): for dtype in (torch.float, torch.double): @@ -187,114 +226,197 @@ def test_gen_batch_initial_conditions(self): bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + for ( + dtype, + nonnegative, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in ( + (torch.float, True, None, None, None, True), + (torch.double, False, 1234, 1, {0: 0.5}, False), + (torch.double, True, 1234, None, {0: 0.5}, True), + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - for nonnegative, seed, init_batch_limit, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, 1], [None, {0: 0.5}], [True, False] - ): - with mock.patch.object( - MockAcquisitionFunction, - "__call__", - wraps=mock_acqf.__call__, - ) as mock_acqf_call, warnings.catch_warnings(): - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=mock_acqf, - bounds=bounds, - q=1, - num_restarts=2, - raw_samples=10, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "init_batch_limit": init_batch_limit, - "sample_around_best": sample_around_best, - }, - ) - expected_shape = torch.Size([2, 1, 2]) - self.assertEqual(batch_initial_conditions.shape, expected_shape) - self.assertEqual(batch_initial_conditions.device, bounds.device) - self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) - self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), - 1e-6, - ) - batch_shape = ( - torch.Size([]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - raw_samps = mock_acqf_call.call_args[0][0] - batch_shape = ( - torch.Size([20 if sample_around_best else 10]) - if init_batch_limit is None - else torch.Size([init_batch_limit]) - ) - expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) - self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + }, + ) + expected_shape = torch.Size([2, 1, 2]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + 
self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) - if ffs is not None: - for idx, val in ffs.items(): - self.assertTrue( - torch.all(batch_initial_conditions[..., idx] == val) - ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) - def test_gen_batch_initial_conditions_highdim(self): - d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) - bounds = torch.stack([torch.zeros(d), torch.ones(d)]) - ffs_map = {i: random() for i in range(0, d, 2)} + def test_gen_batch_initial_conditions_topn(self): + bounds = torch.stack([torch.zeros(2), torch.ones(2)]) mock_acqf = MockAcquisitionFunction() mock_acqf.objective = lambda y: y.squeeze(-1) - for dtype in (torch.float, torch.double): + mock_acqf.maximize = True # Add maximize attribute + for ( + dtype, + topn, + largest, + is_sorted, + seed, + init_batch_limit, + ffs, + sample_around_best, + ) in ( + (torch.float, True, True, True, None, None, None, True), + (torch.double, False, False, False, 1234, 1, {0: 0.5}, False), + (torch.float, True, None, True, 1234, None, None, False), + (torch.double, False, True, False, None, 1, {0: 0.5}, True), + (torch.float, True, False, False, 1234, None, {0: 0.5}, True), + (torch.double, False, None, True, None, 1, None, False), + (torch.float, True, True, False, 1234, 1, {0: 0.5}, True), + (torch.double, False, False, True, None, None, None, False), + ): bounds = bounds.to(device=self.device, dtype=dtype) mock_acqf.X_baseline = bounds # for testing sample_around_best mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) - - for nonnegative, seed, ffs, sample_around_best in product( - [True, False], [None, 1234], [None, ffs_map], [True, False] - ): - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter( - "ignore", category=BadInitialCandidatesWarning - ) - batch_initial_conditions = gen_batch_initial_conditions( - acq_function=MockAcquisitionFunction(), - bounds=bounds, - q=10, - num_restarts=1, - raw_samples=2, - fixed_features=ffs, - options={ - "nonnegative": nonnegative, - "eta": 0.01, - "alpha": 0.1, - "seed": seed, - "sample_around_best": sample_around_best, - }, - ) - self.assertTrue( - any(issubclass(w.category, SamplingWarning) for w in ws) - ) - expected_shape = torch.Size([1, 10, d]) + with mock.patch.object( + MockAcquisitionFunction, + "__call__", + wraps=mock_acqf.__call__, + ) as mock_acqf_call, warnings.catch_warnings(): + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + options = { + "topn": topn, + "sorted": is_sorted, + "seed": seed, + "init_batch_limit": init_batch_limit, + "sample_around_best": sample_around_best, + } + if largest is not None: + options["largest"] = largest + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=mock_acqf, + bounds=bounds, + q=1, + num_restarts=2, + raw_samples=10, + fixed_features=ffs, + options=options, + ) + expected_shape = torch.Size([2, 1, 2]) self.assertEqual(batch_initial_conditions.shape, 
expected_shape) self.assertEqual(batch_initial_conditions.device, bounds.device) self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) self.assertLess( - _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 + _get_max_violation_of_bounds(batch_initial_conditions, bounds), + 1e-6, + ) + batch_shape = ( + torch.Size([]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) ) + raw_samps = mock_acqf_call.call_args[0][0] + batch_shape = ( + torch.Size([20 if sample_around_best else 10]) + if init_batch_limit is None + else torch.Size([init_batch_limit]) + ) + expected_raw_samps_shape = batch_shape + torch.Size([1, 2]) + self.assertEqual(raw_samps.shape, expected_raw_samps_shape) + if ffs is not None: for idx, val in ffs.items(): self.assertTrue( torch.all(batch_initial_conditions[..., idx] == val) ) + def test_gen_batch_initial_conditions_highdim(self): + d = 2200 # 2200 * 10 (q) > 21201 (sobol max dim) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + ffs_map = {i: random() for i in range(0, d, 2)} + mock_acqf = MockAcquisitionFunction() + mock_acqf.objective = lambda y: y.squeeze(-1) + for dtype, nonnegative, seed, ffs, sample_around_best in ( + (torch.float, True, None, None, True), + (torch.double, False, 1234, ffs_map, False), + (torch.double, True, 1234, ffs_map, True), + ): + bounds = bounds.to(device=self.device, dtype=dtype) + mock_acqf.X_baseline = bounds # for testing sample_around_best + mock_acqf.model = MockModel(MockPosterior(mean=bounds[:, :1])) + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("ignore", category=BadInitialCandidatesWarning) + batch_initial_conditions = gen_batch_initial_conditions( + acq_function=MockAcquisitionFunction(), + bounds=bounds, + q=10, + num_restarts=1, + raw_samples=2, + fixed_features=ffs, + options={ + "nonnegative": nonnegative, + "eta": 0.01, + "alpha": 0.1, + "seed": seed, + "sample_around_best": sample_around_best, + }, + ) + self.assertTrue( + any(issubclass(w.category, SamplingWarning) for w in ws) + ) + expected_shape = torch.Size([1, 10, d]) + self.assertEqual(batch_initial_conditions.shape, expected_shape) + self.assertEqual(batch_initial_conditions.device, bounds.device) + self.assertEqual(batch_initial_conditions.dtype, bounds.dtype) + self.assertLess( + _get_max_violation_of_bounds(batch_initial_conditions, bounds), 1e-6 + ) + if ffs is not None: + for idx, val in ffs.items(): + self.assertTrue( + torch.all(batch_initial_conditions[..., idx] == val) + ) + def test_gen_batch_initial_conditions_warning(self) -> None: for dtype in (torch.float, torch.double): bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype) @@ -727,7 +849,9 @@ def generator(n: int, q: int, seed: int | None): dtype=bounds.dtype, device=self.device, ) - X_rnd = bounds[0] + (bounds[1] - bounds[0]) * X_rnd_nlzd + X_rnd = unnormalize( + X_rnd_nlzd, bounds, update_constant_bounds=False + ) X_rnd[..., -1] = 0.42 return X_rnd