diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 0f3b85faa7..763226799e 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -16,11 +16,11 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor -from torch.nn import Module -class FixedFeatureAcquisitionFunction(AcquisitionFunction): +class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AquisitionFunctions to fix a subset of features. Example: @@ -56,8 +56,7 @@ def __init__( combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ - Module.__init__(self) - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) dtype = torch.float device = torch.device("cpu") self.d = d @@ -126,24 +125,13 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Optional[Tensor]): + def set_X_pending(self, X_pending: Optional[Tensor]): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: - self.acq_func.X_pending = self._construct_X_full(X_pending) + full_X_pending = self._construct_X_full(X_pending) else: - self.acq_func.X_pending = X_pending + full_X_pending = None + self.acq_func.set_X_pending(full_X_pending) def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index b114362ea9..9ee8f1fee5 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -15,9 +15,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective -from botorch.exceptions import UnsupportedError +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor: return regularization_term -class PenalizedAcquisitionFunction(AcquisitionFunction): +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: @@ -161,29 +160,16 @@ def __init__( penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ - super().__init__(model=raw_acqf.model) - self.raw_acqf = raw_acqf + AcquisitionFunction.__init__(self, model=raw_acqf.model) + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter def forward(self, X: Tensor) -> Tensor: - raw_value = self.raw_acqf(X=X) + raw_value = self.acq_func(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Optional[Tensor]: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. diff --git a/botorch/acquisition/probabilistic_reparameterization.py b/botorch/acquisition/probabilistic_reparameterization.py new file mode 100644 index 0000000000..5c6428985e --- /dev/null +++ b/botorch/acquisition/probabilistic_reparameterization.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Probabilistic Reparameterization (with gradients) using Monte Carlo estimators. + +See [Daulton2022bopr]_ for details. +""" + +from contextlib import ExitStack +from typing import Dict, List, Optional + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.models.transforms.factory import ( + get_probabilistic_reparameterization_input_transform, +) + +from botorch.models.transforms.input import ( + ChainedInputTransform, + InputTransform, + OneHotToNumeric, +) +from torch import Tensor +from torch.autograd import Function +from torch.nn.functional import one_hot + + +class _MCProbabilisticReparameterization(Function): + r"""Evaluate the acquisition function via probabistic reparameterization. + + This uses a score function gradient estimator. See [Daulton2022bopr]_ for details. + """ + + @staticmethod + def forward( + ctx, + X: Tensor, + acq_function: AcquisitionFunction, + input_tf: InputTransform, + batch_limit: Optional[int], + integer_indices: Tensor, + cont_indices: Tensor, + categorical_indices: Tensor, + use_ma_baseline: bool, + one_hot_to_numeric: Optional[OneHotToNumeric], + ma_counter: Optional[Tensor], + ma_hidden: Optional[Tensor], + ma_decay: Optional[float], + ): + """Evaluate the expectation of the acquisition function under + probabilistic reparameterization. Compute this in chunks of size + batch_limit to enable scaling to large numbers of samples from the + proposal distribution. + """ + with ExitStack() as es: + if ctx.needs_input_grad[0]: + es.enter_context(torch.enable_grad()) + if cont_indices.shape[0] > 0: + # only require gradient for continuous parameters + ctx.cont_X = X[..., cont_indices].detach().requires_grad_(True) + cont_idx = 0 + cols = [] + for col in range(X.shape[-1]): + # cont_indices is sorted in ascending order + if ( + cont_idx < cont_indices.shape[0] + and col == cont_indices[cont_idx] + ): + cols.append(ctx.cont_X[..., cont_idx]) + cont_idx += 1 + else: + cols.append(X[..., col]) + X = torch.stack(cols, dim=-1) + else: + ctx.cont_X = None + ctx.discrete_indices = input_tf["round"].discrete_indices + ctx.cont_indices = cont_indices + ctx.categorical_indices = categorical_indices + ctx.ma_counter = ma_counter + ctx.ma_hidden = ma_hidden + ctx.X_shape = X.shape + tilde_x_samples = input_tf(X.unsqueeze(-3)) + # save the rounding component + + rounding_component = tilde_x_samples.clone() + if integer_indices.shape[0] > 0: + X_integer_params = X[..., integer_indices].unsqueeze(-3) + rounding_component[..., integer_indices] = ( + (tilde_x_samples[..., integer_indices] - X_integer_params > 0) + | (X_integer_params == 1) + ).to(tilde_x_samples) + if categorical_indices.shape[0] > 0: + rounding_component[..., categorical_indices] = tilde_x_samples[ + ..., categorical_indices + ] + ctx.rounding_component = rounding_component[..., ctx.discrete_indices] + ctx.tau = input_tf["round"].tau + if hasattr(input_tf["round"], "base_samples"): + ctx.base_samples = input_tf["round"].base_samples.detach() + # save the probabilities + if "unnormalize" in input_tf: + unnormalized_X = input_tf["unnormalize"](X) + else: + unnormalized_X = X + # this is only for the integer parameters + ctx.prob = input_tf["round"].get_rounding_prob(unnormalized_X) + + if categorical_indices.shape[0] > 0: + ctx.base_samples_categorical = input_tf[ + "round" + ].base_samples_categorical.clone() + # compute the acquisition function where inputs are rounded according to base_samples < prob + ctx.tilde_x_samples = tilde_x_samples + ctx.use_ma_baseline = use_ma_baseline + acq_values_list = [] + start_idx = 0 + if one_hot_to_numeric is not None: + tilde_x_samples = one_hot_to_numeric(tilde_x_samples) + + while start_idx < tilde_x_samples.shape[-3]: + end_idx = min(start_idx + batch_limit, tilde_x_samples.shape[-3]) + acq_values = acq_function(tilde_x_samples[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += batch_limit + acq_values = torch.cat(acq_values_list, dim=-1) + ctx.mean_acq_values = acq_values.mean( + dim=-1 + ) # average over samples from proposal distribution + ctx.acq_values = acq_values + # update moving average baseline + ctx.ma_hidden = ma_hidden.clone() + ctx.ma_counter = ctx.ma_counter.clone() + ctx.ma_decay = ma_decay + # update in place + ma_counter.add_(1) + ma_hidden.sub_((ma_hidden - acq_values.detach().mean()) * (1 - ma_decay)) + return ctx.mean_acq_values.detach() + + @staticmethod + def backward(ctx, grad_output): + """ + Compute the gradient of the expectation of the acquisition function + with respect to the parameters of the proposal distribution using + Monte Carlo. + """ + # this is overwriting the entire gradient w.r.t. x' + # x' has shape batch_shape x q x d + if ctx.needs_input_grad[0]: + acq_values = ctx.acq_values + mean_acq_values = ctx.mean_acq_values + cont_indices = ctx.cont_indices + discrete_indices = ctx.discrete_indices + rounding_component = ctx.rounding_component + # retrieve only the ordinal parameters + expanded_acq_values = acq_values.view(*acq_values.shape, 1, 1).expand( + acq_values.shape + rounding_component.shape[-2:] + ) + prob = ctx.prob.unsqueeze(-3) + if not ctx.use_ma_baseline: + sample_level = expanded_acq_values * (rounding_component - prob) + else: + # use reinforce with the moving average baseline + if ctx.ma_counter == 0: + baseline = 0.0 + else: + baseline = ctx.ma_hidden / ( + 1.0 - torch.pow(ctx.ma_decay, ctx.ma_counter) + ) + sample_level = (expanded_acq_values - baseline) * ( + rounding_component - prob + ) + + grads = (sample_level / ctx.tau).mean(dim=-3) + + new_grads = ( + grad_output.view( + *grad_output.shape, + *[1 for _ in range(grads.ndim - grad_output.ndim)], + ) + .expand(*grad_output.shape, *ctx.X_shape[-2:]) + .clone() + ) + # multiply upstream grad_output by new gradients + new_grads[..., discrete_indices] *= grads + # use autograd for gradients w.r.t. the continuous parameters + if ctx.cont_X is not None: + auto_grad = torch.autograd.grad( + # note: this multiplies the gradient of mean_acq_values w.r.t to input + # by grad_output + mean_acq_values, + ctx.cont_X, + grad_outputs=grad_output, + )[0] + # overwrite grad_output since the previous step already applied the chain rule + new_grads[..., cont_indices] = auto_grad + return ( + new_grads, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + return None, None, None, None, None, None, None, None, None, None, None, None + + +class AbstractProbabilisticReparameterization(AbstractAcquisitionFunctionWrapper): + r"""Acquisition Function Wrapper that leverages probabilistic reparameterization. + + The forward method is abstract and must be implemented. + + See [Daulton2022bopr]_ for details. + """ + + input_transform: ChainedInputTransform + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + **kwargs, + ) -> None: + r"""Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + """ + if categorical_features is None and integer_indices is None: + raise NotImplementedError( + "categorical_features or integer indices must be provided." + ) + super().__init__(acq_function=acq_function) + self.batch_limit = batch_limit + + if apply_numeric: + self.one_hot_to_numeric = OneHotToNumeric( + categorical_features=categorical_features, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + ) + self.one_hot_to_numeric.eval() + else: + self.one_hot_to_numeric = None + discrete_indices = [] + if integer_indices is not None: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices.extend(integer_indices) + else: + self.register_buffer( + "integer_indices", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + self.register_buffer( + "integer_bounds", + torch.tensor( + [], dtype=one_hot_bounds.dtype, device=one_hot_bounds.device + ), + ) + dim = one_hot_bounds.shape[1] + if categorical_features is not None and len(categorical_features) > 0: + categorical_indices = list(range(min(categorical_features.keys()), dim)) + discrete_indices.extend(categorical_indices) + self.register_buffer( + "categorical_indices", + torch.tensor( + categorical_indices, + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.categorical_features = categorical_features + else: + self.register_buffer( + "categorical_indices", + torch.tensor( + [], + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + + self.register_buffer( + "cont_indices", + torch.tensor( + sorted(set(range(dim)) - set(discrete_indices)), + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.model = acq_function.model # for sample_around_best heuristic + # moving average baseline + self.register_buffer( + "ma_counter", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_hidden", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_baseline", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + + def sample_candidates(self, X: Tensor) -> Tensor: + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X.clone() + prob = self.input_transform["round"].get_rounding_prob(X=unnormalized_X) + discrete_idx = 0 + for i in self.integer_indices: + p = prob[..., discrete_idx] + rounding_component = torch.distributions.Bernoulli(probs=p).sample() + unnormalized_X[..., i] = unnormalized_X[..., i].floor() + rounding_component + discrete_idx += 1 + if len(self.integer_indices) > 0: + unnormalized_X[..., self.integer_indices] = torch.minimum( + torch.maximum( + unnormalized_X[..., self.integer_indices], self.integer_bounds[0] + ), + self.integer_bounds[1], + ) + # this is the starting index for the categoricals in unnormalized_X + raw_idx = self.cont_indices.shape[0] + discrete_idx + if self.categorical_indices.shape[0] > 0: + for cardinality in self.categorical_features.values(): + discrete_end = discrete_idx + cardinality + p = prob[..., discrete_idx:discrete_end] + z = one_hot( + torch.distributions.Categorical(probs=p).sample(), + num_classes=cardinality, + ) + raw_end = raw_idx + cardinality + unnormalized_X[..., raw_idx:raw_end] = z + discrete_idx = discrete_end + raw_idx = raw_end + # normalize X + if "normalize" in self.input_transform: + return self.input_transform["normalize"](unnormalized_X) + return unnormalized_X + + +class AnalyticProbabilisticReparameterization(AbstractProbabilisticReparameterization): + """Analytic probabilistic reparameterization. + + Note: this is only reasonable from a computation perspective for relatively + small numbers of discrete options (probably less than a few thousand). + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + tau: float = 0.1, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + tau: The temperature parameter used to determine the probabilities. + + """ + super().__init__( + acq_function=acq_function, + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + # create input transform + # need to compute cross product of discrete options and weights + self.input_transform = get_probabilistic_reparameterization_input_transform( + one_hot_bounds=one_hot_bounds, + use_analytic=True, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate PR.""" + X_discrete_all = self.input_transform(X.unsqueeze(-3)) + acq_values_list = [] + start_idx = 0 + if self.one_hot_to_numeric is not None: + X_discrete_all = self.one_hot_to_numeric(X_discrete_all) + if X.shape[-2] != 1: + raise NotImplementedError + + # save the probabilities + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X + # this is batch_shape x n_discrete (after squeezing) + probs = self.input_transform["round"].get_probs(X=unnormalized_X).squeeze(-1) + # TODO: filter discrete configs with zero probability + # this would require padding because there may be a different number in each batch. + while start_idx < X_discrete_all.shape[-3]: + end_idx = min(start_idx + self.batch_limit, X_discrete_all.shape[-3]) + acq_values = self.acq_func(X_discrete_all[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += self.batch_limit + # this is batch_shape x n_discrete + acq_values = torch.cat(acq_values_list, dim=-1) + # now weight the acquisition values by probabilities + return (acq_values * probs).sum(dim=-1) + + +class MCProbabilisticReparameterization(AbstractProbabilisticReparameterization): + r"""MC-based probabilistic reparameterization. + + See [Daulton2022bopr]_ for details. + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + mc_samples: int = 128, + use_ma_baseline: bool = True, + tau: float = 0.1, + ma_decay: float = 0.7, + resample: bool = True, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + use_ma_baseline: A boolean indicating whether to use a moving average + baseline for variance reduction. + tau: The temperature parameter used to determine the probabilities. + ma_decay: The decay parameter in the moving average baseline. + Default: 0.7 + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + + """ + super().__init__( + acq_function=acq_function, + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + if self.batch_limit is None: + self.batch_limit = mc_samples + self.use_ma_baseline = use_ma_baseline + self._pr_acq_function = _MCProbabilisticReparameterization() + # create input transform + self.input_transform = get_probabilistic_reparameterization_input_transform( + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + mc_samples=mc_samples, + tau=tau, + resample=resample, + ) + self.ma_decay = ma_decay + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate MC probabilistic reparameterization.""" + return self._pr_acq_function.apply( + X, + self.acq_func, + self.input_transform, + self.batch_limit, + self.integer_indices, + self.cont_indices, + self.categorical_indices, + self.use_ma_baseline, + self.one_hot_to_numeric, + self.ma_counter, + self.ma_hidden, + self.ma_decay, + ) diff --git a/botorch/acquisition/proximal.py b/botorch/acquisition/proximal.py index 9cd4aed7ad..b1d68edef1 100644 --- a/botorch/acquisition/proximal.py +++ b/botorch/acquisition/proximal.py @@ -15,6 +15,8 @@ import torch from botorch.acquisition import AcquisitionFunction + +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models import ModelListGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -25,7 +27,7 @@ from torch.nn import Module -class ProximalAcquisitionFunction(AcquisitionFunction): +class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AcquisitionFunctions to add proximal weighting of the acquisition function. The acquisition function is weighted via a squared exponential centered at the last training point, @@ -70,9 +72,7 @@ def __init__( beta: If not None, apply a softplus transform to the base acquisition function, allows negative base acquisition function values. """ - Module.__init__(self) - - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) model = self.acq_func.model if hasattr(acq_function, "X_pending"): @@ -80,7 +80,6 @@ def __init__( raise UnsupportedError( "Proximal acquisition function requires `X_pending` to be None." ) - self.X_pending = acq_function.X_pending self.register_buffer("proximal_weights", proximal_weights) self.register_buffer( @@ -91,6 +90,12 @@ def __init__( _validate_model(model, proximal_weights) + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + raise UnsupportedError( + "Proximal acquisition function does not support `X_pending`." + ) + @t_batch_mode_transform(expected_q=1, assert_output_shape=False) def forward(self, X: Tensor) -> Tensor: r"""Evaluate base acquisition function with proximal weighting. diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 486fdd0cff..ccbbf471b2 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import math -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 @@ -22,6 +22,7 @@ MCAcquisitionObjective, PosteriorTransform, ) +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None): return -(lb.clamp_max(0.0)) +def isinstance_af( + __obj: object, + __class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]], +) -> bool: + r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions.""" + if isinstance(__obj, AbstractAcquisitionFunctionWrapper): + isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple) + else: + isinstance_base_af = False + return isinstance_base_af or isinstance(__obj, __class_or_tuple) + + def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..08dfbd2849 --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor +from torch.nn import Module + + +class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.acq_func = acq_function + + @property + def X_pending(self) -> Optional[Tensor]: + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.set_X_pending(X_pending) + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the wrapped acquisition function on the candidate set X. + + Args: + X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim + design points each. + + Returns: + A `(b)`-dim Tensor of acquisition function values at the given + design points `X`. + """ + pass # pragma: no cover diff --git a/botorch/models/transforms/factory.py b/botorch/models/transforms/factory.py index 847fdf1b7c..486dbc3125 100644 --- a/botorch/models/transforms/factory.py +++ b/botorch/models/transforms/factory.py @@ -10,7 +10,9 @@ from typing import Dict, List, Optional from botorch.models.transforms.input import ( + AnalyticProbabilisticReparameterizationInputTransform, ChainedInputTransform, + MCProbabilisticReparameterizationInputTransform, Normalize, OneHotToNumeric, Round, @@ -123,3 +125,83 @@ def get_rounding_input_transform( tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device) tf.eval() return tf + + +def get_probabilistic_reparameterization_input_transform( + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + use_analytic: bool = False, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, +) -> ChainedInputTransform: + r"""Construct InputTransform for Probabilistic Reparameterization. + + Note: this is intended to be used only for acquisition optimization + in via the AnalyticProbabilisticReparameterization and + MCProbabilisticReparameterization classes. This is not intended to be + attached to a botorch Model. + + See [Daulton2022bopr]_ for details. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + use_analytic: A boolean indicating whether to use analytic + probabilistic reparameterization. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + tau: The temperature parameter used to determine the probabilities. + + Returns: + The probabilistic reparameterization input transformation. + """ + tfs = OrderedDict() + if integer_indices is not None and len(integer_indices) > 0: + # unnormalize to integer space + tfs["unnormalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=True, + ) + if use_analytic: + tfs["round"] = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + else: + tfs["round"] = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + resample=resample, + mc_samples=mc_samples, + tau=tau, + ) + if integer_indices is not None and len(integer_indices) > 0: + # normalize to unit cube + tfs["normalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=False, + ) + tf = ChainedInputTransform(**tfs) + tf.eval() + return tf diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 09310163b5..0bc649dedf 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -25,6 +25,7 @@ from botorch.models.transforms.utils import subset_transform from botorch.models.utils import fantasize from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE +from botorch.utils.sampling import draw_sobol_samples from gpytorch import Module as GPyTorchModule from gpytorch.constraints import GreaterThan from gpytorch.priors import Prior @@ -1503,3 +1504,574 @@ def equals(self, other: InputTransform) -> bool: and (self.transform_on_fantasize == other.transform_on_fantasize) and self.categorical_features == other.categorical_features ) + + +class AnalyticProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete values are created. + 3. All values are normalized to the unitcube. + + TODO: consolidate this with MCProbabilisticReparameterizationInputTransform. + + """ + + def __init__( + self, + one_hot_bounds: Tensor = None, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." + ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.tau = tau + # create cartesian product of discrete options + discrete_options = [] + dim = one_hot_bounds.shape[1] + # get number of discrete parameters + num_discrete_params = 0 + if self.integer_indices is not None: + num_discrete_params += self.integer_indices.shape[0] + if self.categorical_features is not None: + num_discrete_params += len(self.categorical_features) + # add zeros for continuous params to simplify code + for _ in range(dim - len(discrete_indices)): + discrete_options.append( + torch.zeros( + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + if integer_indices is not None: + for i in range(self.integer_bounds.shape[-1]): + discrete_options.append( + torch.arange( + self.integer_bounds[0, i], + self.integer_bounds[1, i] + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + categorical_start_idx = len(discrete_options) + if categorical_features is not None: + for idx in sorted(categorical_features.keys()): + cardinality = categorical_features[idx] + discrete_options.append( + torch.arange( + cardinality, dtype=torch.long, device=one_hot_bounds.device + ) + ) + # categoricals are in numeric representation + all_discrete_options = torch.cartesian_prod(*discrete_options) + # one-hot encode the categoricals + if categorical_features is not None and len(categorical_features) > 0: + X_categ = torch.empty( + *all_discrete_options.shape[:-1], sum(categorical_features.values()) + ) + start = 0 + for i, (idx, cardinality) in enumerate( + sorted(categorical_features.items(), key=lambda kv: kv[0]) + ): + start = idx - categorical_start_idx + X_categ[..., start : start + cardinality] = one_hot( + all_discrete_options[..., i], + num_classes=cardinality, + ).to(X_categ) + all_discrete_options = torch.cat( + [all_discrete_options[..., : -len(categorical_features)], X_categ], + dim=-1, + ) + self.register_buffer("all_discrete_options", all_discrete_options) + + def get_rounding_prob(self, X: Tensor) -> Tensor: + # todo consolidate this the MCProbabilisticReparameterizationInputTransform + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def get_probs(self, X: Tensor) -> Tensor: + """ + Args: + X: a `batch_shape x n x d`-dim tensor + + Returns: + A `batch_shape x n_discrete x n`-dim tensors of probabilities of each discrete config under X. + """ + # note this method should be differentiable + X_prob = torch.ones( + *X.shape[:-2], + self.all_discrete_options.shape[0], + X.shape[-2], + dtype=X.dtype, + device=X.device, + ) + # n_discrete x batch_shape x n x d + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 2)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-2], self.all_discrete_options.shape[0], *X.shape[-2:]) + X = X.unsqueeze(-3) + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + # note we don't actually need the sigmoid here + X_prob_int = torch.sigmoid((X_int_abs - offset - 0.5) / self.tau) + # X_prob_int = X_int_abs - offset + for int_idx, idx in enumerate(self.integer_indices): + offset_i = offset[..., int_idx] + all_discrete_i = all_discrete_options[..., idx] + diff = (offset_i + 1) - all_discrete_i + round_up_mask = diff == 0 + round_down_mask = diff == 1 + neither_mask = ~(round_up_mask | round_down_mask) + prob = X_prob_int[..., int_idx].expand(round_up_mask.shape) + # need to be careful with in-place ops here for autograd + X_prob[round_up_mask] = X_prob[round_up_mask] * prob[round_up_mask] + X_prob[round_down_mask] = X_prob[round_down_mask] * ( + 1 - prob[round_down_mask] + ) + X_prob[neither_mask] = X_prob[neither_mask] * 0 + + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X[..., start:end] + X_prob_c = torch.softmax((X_categ - 0.5) / self.tau, dim=-1).expand( + *X_categ.shape[:-3], all_discrete_options.shape[-3], *X_categ.shape[-2:] + ) + for i in range(X_prob_c.shape[-1]): + mask = all_discrete_options[..., start + i] == 1 + X_prob[mask] = X_prob[mask] * X_prob_c[..., i][mask] + + return X_prob + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x 1 x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n_discrete x n x d`-dim tensor of rounded inputs. + """ + n_discrete = self.discrete_indices.shape[0] + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 3)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:]) + if X.shape[-1] > n_discrete: + X = X.expand( + *X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:] + ) + return torch.cat( + [X[..., :-n_discrete], all_discrete_options[..., -n_discrete:]], dim=-1 + ) + return all_discrete_options + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + # TODO: update this + return super().equals(other=other) and torch.equal( + self.integer_indices, other.integer_indices + ) + + +class MCProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete ordinal valeus are sampled. + 3. All values are normalized to the unitcube. + """ + + def __init__( + self, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", torch.tensor(integer_indices, dtype=torch.long) + ) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." + ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + if integer_indices is None: + self.register_buffer( + "integer_bounds", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + else: + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + self.mc_samples = mc_samples + self.resample = resample + self.tau = tau + + def get_rounding_prob(self, X: Tensor) -> Tensor: + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of rounded inputs. + """ + X_expanded = X.expand(*X.shape[:-3], self.mc_samples, *X.shape[-2:]).clone() + X_prob = self.get_rounding_prob(X=X) + if self.integer_indices is not None: + X_int = X[..., self.integer_indices].detach() + assert X.ndim > 1 + if X.ndim == 2: + X.unsqueeze(-1) + if ( + not hasattr(self, "base_samples") + or self.base_samples.shape[-2:] != X_int.shape[-2:] + or self.resample + ): + # construct sobol base samples + bounds = torch.zeros( + 2, X_int.shape[-1], dtype=X_int.dtype, device=X_int.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X_int.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + X_int_abs = X_int.abs() + # perform exact rounding + is_negative = X_int < 0 + offset = X_int_abs.floor() + prob = X_prob[..., : self.integer_indices.shape[0]] + rounding_component = (prob >= self.base_samples).to( + dtype=X.dtype, + ) + X_abs_rounded = offset + rounding_component + X_int_new = (-1) ** is_negative.to(offset) * X_abs_rounded + # clamp to bounds + X_expanded[..., self.integer_indices] = torch.minimum( + torch.maximum(X_int_new, self.integer_bounds[0]), self.integer_bounds[1] + ) + + # sample for categoricals + if self.categorical_features is not None and len(self.categorical_features) > 0: + if ( + not hasattr(self, "base_samples_categorical") + or self.base_samples_categorical.shape[-2] != X.shape[-2] + or self.resample + ): + bounds = torch.zeros( + 2, len(self.categorical_features), dtype=X.dtype, device=X.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples_categorical", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + + # sample from multinomial as argmin_c [sample_c * exp(-x_c)] + sample_d_start_idx = 0 + X_categ_prob = X_prob + if self.integer_indices is not None: + n_ints = self.integer_indices.shape[0] + if n_ints > 0: + X_categ_prob = X_prob[..., n_ints:] + + for i, cardinality in enumerate(self.categorical_features.values()): + sample_d_end_idx = sample_d_start_idx + cardinality + start = self.categorical_starts[i] + end = self.categorical_ends[i] + cum_prob = X_categ_prob[ + ..., sample_d_start_idx:sample_d_end_idx + ].cumsum(dim=-1) + categories = ( + ( + (cum_prob > self.base_samples_categorical[..., i : i + 1]) + .long() + .cumsum(dim=-1) + == 1 + ) + .long() + .argmax(dim=-1) + ) + # one-hot encode + X_expanded[..., start:end] = one_hot( + categories, num_classes=cardinality + ).to(X) + sample_d_start_idx = sample_d_end_idx + + return X_expanded + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + super().equals(other=other) + and (self.resample == other.resample) + and torch.equal(self.base_samples, other.base_samples) + and torch.equal(self.integer_indices, other.integer_indices) + ) diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..a5a429e46b 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -21,6 +21,11 @@ Analytic Acquisition Function API .. autoclass:: AnalyticAcquisitionFunction :members: +Acquisition Function Wrapper API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.wrapper + :members: + Cached Cholesky Acquisition Function API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.cached_cholesky @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search @@ -175,6 +180,11 @@ Penalized Acquisition Function Wrapper .. automodule:: botorch.acquisition.penalized :members: +Probabilistic Reparameterization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.probabilistic_reparameterization + :members: + Proximal Acquisition Function Wrapper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.proximal diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index 8dcc02f1df..b8f570e7e1 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -87,7 +87,7 @@ def test_fixed_features(self): qEI_ff.set_X_pending(X_pending[..., :-1]) self.assertAllClose(qEI.X_pending, X_pending) # test setting to None - qEI_ff.X_pending = None + qEI_ff.set_X_pending(None) self.assertIsNone(qEI_ff.X_pending) # test gradient diff --git a/test/acquisition/test_proximal.py b/test/acquisition/test_proximal.py index 795daa1b34..e17536ddd0 100644 --- a/test/acquisition/test_proximal.py +++ b/test/acquisition/test_proximal.py @@ -209,9 +209,15 @@ def test_proximal(self): # test for x_pending points pending_acq = DummyAcquisitionFunction(model) - pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype)) + X_pending = torch.rand(3, 3, device=self.device, dtype=dtype) + pending_acq.set_X_pending(X_pending) with self.assertRaises(UnsupportedError): ProximalAcquisitionFunction(pending_acq, proximal_weights) + # test setting pending points + pending_acq.set_X_pending(None) + af = ProximalAcquisitionFunction(pending_acq, proximal_weights) + with self.assertRaises(UnsupportedError): + af.set_X_pending(X_pending) # test model with multi-batch training inputs train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype) diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index d12b5f6da4..39b8017ea2 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -8,7 +8,8 @@ from unittest import mock import torch -from botorch.acquisition import monte_carlo +from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.multi_objective import ( MCMultiOutputObjective, monte_carlo as moo_monte_carlo, @@ -18,10 +19,13 @@ MCAcquisitionObjective, ScalarizedPosteriorTransform, ) +from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import ( expand_trace_observations, get_acquisition_function, get_infeasible_cost, + is_nonnegative, + isinstance_af, project_to_sample_points, project_to_target_fidelity, prune_inferior_points, @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + +class TestIsinstanceAf(BotorchTestCase): + def test_isinstance_af(self): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) + wrapped_af = FixedFeatureAcquisitionFunction( + acq_function=acq_func, d=2, columns=[1], values=[0.0] + ) + # test base af class + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) + # test wrapper class + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) + self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) + + class TestPruneInferiorPoints(BotorchTestCase): def test_prune_inferior_points(self): for dtype in (torch.float, torch.double): diff --git a/test/acquisition/test_wrapper.py b/test/acquisition/test_wrapper.py new file mode 100644 index 0000000000..e35175fb9b --- /dev/null +++ b/test/acquisition/test_wrapper.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.exceptions.errors import UnsupportedError +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyWrapper(AbstractAcquisitionFunctionWrapper): + def forward(self, X): + return self.acq_func(X) + + +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): + def test_abstract_acquisition_function_wrapper(self): + for dtype in (torch.float, torch.double): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, dtype=dtype, device=self.device), + variance=torch.ones(1, 1, dtype=dtype, device=self.device), + ) + ) + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIs(wrapped_af.acq_func, acq_func) + # test forward + X = torch.rand(1, 1, dtype=dtype, device=self.device) + with torch.no_grad(): + wrapped_val = wrapped_af(X) + af_val = acq_func(X) + self.assertEqual(wrapped_val.item(), af_val.item()) + + # test X_pending + with self.assertRaises(ValueError): + self.assertIsNone(wrapped_af.X_pending) + with self.assertRaises(UnsupportedError): + wrapped_af.set_X_pending(X) + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIsNone(wrapped_af.X_pending) + wrapped_af.set_X_pending(X) + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) + self.assertTrue(torch.equal(X, acq_func.X_pending)) + wrapped_af.set_X_pending(None) + self.assertIsNone(wrapped_af.X_pending) + self.assertIsNone(acq_func.X_pending) diff --git a/tutorials/discrete_mixed_bo.ipynb b/tutorials/discrete_mixed_bo.ipynb new file mode 100644 index 0000000000..38bea5e867 --- /dev/null +++ b/tutorials/discrete_mixed_bo.ipynb @@ -0,0 +1,719 @@ +{ + "metadata": { + "dataExplorerConfig": {}, + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3", + "cinder_runtime": true, + "ipyflow_runtime": false, + "metadata": { + "is_prebuilt": false, + "kernel_name": "ae_local", + "cinder_runtime": true, + "ipyflow_runtime": false + } + }, + "last_server_session_id": "96862577-2347-40d6-a7f0-e7fda4e0662d", + "last_kernel_id": "28f4e6d7-8588-4a44-a9cd-72f1c86fe424", + "last_base_url": "https://10809.od.fbinfra.net:443/", + "last_msg_id": "4fc1bf58-e85662233f122503abcfda0d_71", + "outputWidgetContext": {} + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "originalKey": "00018e33-90ca-4f63-b741-fe1fd43ca7db", + "showInput": false, + "collapsed": false + }, + "source": [ + "## Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic Reparameterization\n", + "\n", + "In this tutorial, we illustrate how to perform Bayesian Optimization (BO) over discrete and mixed spaces via probabilistic reparameterization.\n", + "\n", + "The key idea is that we can optimize an acquisition function $\\alpha(x, z)$ with discrete variables $z$ (and potentially continuous variables $ x$) by reparameterizing the discrete variables with random discrete varaibles $ Z$ that are parameterized by continuous parameters $\\theta$. This reparameterization enables optimizing the acquisition function by optimizing the following probabilistic objective:\n", + "$$\\mathbb E_{Z \\sim P(Z|\\theta)}[\\alpha(x, Z)].$$\n", + "\n", + "The probabilistic objective is differentiable with respect to $\\theta$ (and $x$ so long as the acquisition function is differentiable with respect to $x$) and hence we can optimize the acquisition function with gradients.\n", + "\n", + "In this tutorial, we demonstrate how to use both an analytic version of probabilistic reparameterization (suitable when there are less than a few thousand discrete options) and a scalable Monte Carlo variant in BoTorch.\n", + "\n", + "S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy. [Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic Reparameterization](https://arxiv.org/abs/2210.10199), NeurIPS, 2022. " + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "029a9a6f-8db6-4d55-b122-d9ba064765ed", + "collapsed": false, + "requestMsgId": "f795e05e-6be2-4567-a6c6-8abe360c1e7f", + "customOutput": null, + "executionStartTime": 1669843260741, + "executionStopTime": 1669843260795 + }, + "source": [ + "import os\n", + "from typing import Optional\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from botorch.models.transforms.factory import get_rounding_input_transform\n", + "from botorch.test_functions.synthetic import SyntheticTestFunction\n", + "from botorch.utils.sampling import draw_sobol_samples, manual_seed\n", + "from botorch.utils.transforms import unnormalize\n", + "from torch import Tensor\n", + "\n", + "\n", + "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n", + "dtype = torch.double\n", + "tkwargs = {\"dtype\": dtype, \"device\": device}\n", + "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1130 132056.081 _utils_internal.py:179] NCCL_DEBUG env var is set to None\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "I1130 132056.082 _utils_internal.py:188] NCCL_DEBUG is INFO from /etc/nccl.conf\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4ba4e568-0ef2-430e-aecb-dae8889d6664", + "showInput": false, + "collapsed": false + }, + "source": [ + "### Problem setup\n", + "\n", + "Setup a mixed Ackley proble with 10 binary parameters and 3 continuous parameters." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b4930ece-f6e2-43a7-ae98-6ff272608b98", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "5f741506-5e09-4ee2-ba37-936720d3177d", + "customOutput": null, + "executionStartTime": 1669843260807, + "executionStopTime": 1669843261757 + }, + "source": [ + "from botorch.test_functions.synthetic import Ackley\n", + "dim = 13\n", + "base_function = Ackley(dim=dim, negate=True).to(**tkwargs)\n", + "# restrict ackley search space\n", + "base_function.bounds[0, :-3] = 0\n", + "base_function.bounds[1] = 1\n", + "base_function.bounds[0, -3:] = -1\n", + "# define integer bounds (binary)\n", + "integer_bounds = torch.zeros(2, dim - 3, **tkwargs)\n", + "integer_bounds[1] = 1 # 2 values\n", + "rounding_bounds = base_function.bounds.clone()\n", + "integer_indices = list(range(dim - 3))\n", + "rounding_bounds[:, integer_indices] = integer_bounds\n", + "standard_bounds = torch.zeros_like(rounding_bounds)\n", + "standard_bounds[1] = 1" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "223e42ea-10db-4510-b2b6-d51dbf114052", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "6a161416-817d-48d7-bacb-816d02b8bfca", + "customOutput": null, + "executionStartTime": 1669843261778, + "executionStopTime": 1669843261787 + }, + "source": [ + "\n", + "# construct a rounding function for initialization (equal probability for all discrete values)\n", + "init_exact_rounding_func = get_rounding_input_transform(\n", + " one_hot_bounds=rounding_bounds, integer_indices=integer_indices, initialization=True\n", + ")\n", + "# construct a rounding function\n", + "exact_rounding_func = get_rounding_input_transform(\n", + " one_hot_bounds=rounding_bounds, integer_indices=integer_indices, initialization=False\n", + ")" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "d569119b-39d6-4de6-a5fc-83057239323f", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "1c87ba69-f582-43e7-84f0-879c7672481d", + "customOutput": null, + "executionStartTime": 1669843261804, + "executionStopTime": 1669843261812 + }, + "source": [ + "def eval_problem(X):\n", + " # apply the exact rounding function to make sure\n", + " # that discrete parameters are discretized\n", + " X = exact_rounding_func(X)\n", + " # unnormalize to the problem space\n", + " raw_X = unnormalize(X, base_function.bounds)\n", + " return base_function(raw_X).unsqueeze(-1)" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b3c7d450-b82d-4e69-8c0c-667dc2ba6f17", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "3f711388-d3d6-479f-aff4-85c456647121", + "customOutput": null, + "executionStartTime": 1669843261842, + "executionStopTime": 1669843261865 + }, + "source": [ + "def generate_initial_data(n):\n", + " r\"\"\"\n", + " Generates the initial data for the experiments.\n", + " Args:\n", + " n: Number of training points..\n", + " Returns:\n", + " The train_X and train_Y. `n x d` and `n x 1`.\n", + " \"\"\"\n", + " raw_x = draw_sobol_samples(bounds=standard_bounds, n=n, q=1).squeeze(-2)\n", + " train_x = init_exact_rounding_func(raw_x)\n", + " train_obj = eval_problem(train_x)\n", + " return train_x, train_obj" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "c262e98f-924d-414e-889f-7e65a37d3689", + "showInput": false, + "collapsed": false + }, + "source": [ + "#### Model initialization\n", + "\n", + "We use a `FixedNoiseGP` to model the outcome. The models are initialized with 20 quasi-random points. We use an isotropic kernel over the binary parameters and an ARD kernel over the continuous parameters." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "0d3cd746-3818-47c1-b16a-622c6035c264", + "collapsed": false, + "requestMsgId": "d499c927-365f-488c-82df-dafc04786744", + "customOutput": null, + "executionStartTime": 1669843261876, + "executionStopTime": 1669843261884 + }, + "source": [ + "from typing import Dict, List, Optional\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from botorch.models import FixedNoiseGP\n", + "from botorch.models.kernels import CategoricalKernel\n", + "from gpytorch.constraints import GreaterThan, Interval\n", + "from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel\n", + "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n", + "from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior\n", + "from torch import Tensor\n", + "\n", + "\n", + "def get_kernel(dim: int, binary_dims: List[int]) -> Kernel:\n", + " \"\"\"Helper function for kernel construction.\"\"\"\n", + " # ard kernel for continuous features\n", + " cont_dims = list(set(list(range(dim))) - set(binary_dims))\n", + " cont_kernel = MaternKernel(\n", + " nu=2.5,\n", + " ard_num_dims=len(cont_dims),\n", + " active_dims=cont_dims,\n", + " lengthscale_constraint=Interval(0.1, 20.0),\n", + " )\n", + " # isotropic kernel for binary features\n", + " binary_kernel = MaternKernel(\n", + " nu=2.5,\n", + " ard_num_dims=None,\n", + " active_dims=binary_dims,\n", + " lengthscale_constraint=Interval(0.1, 20.0),\n", + " )\n", + " return ScaleKernel(cont_kernel * binary_kernel)\n", + "\n", + "\n", + "NOISE_SE = 1e-6\n", + "train_yvar = torch.tensor(NOISE_SE**2, device=device, dtype=dtype)\n", + "\n", + "\n", + "def initialize_model(\n", + " train_x, stdized_train_obj, state_dict=None, exact_rounding_func=None\n", + "):\n", + " # define model\n", + " model = FixedNoiseGP(\n", + " train_x,\n", + " stdized_train_obj,\n", + " train_yvar.expand_as(stdized_train_obj),\n", + " input_transform=exact_rounding_func,\n", + " covar_module=get_kernel(dim=dim, binary_dims=integer_indices),\n", + " ).to(train_x)\n", + " mll = ExactMarginalLogLikelihood(model.likelihood, model)\n", + " # load state dict if it is passed\n", + " if state_dict is not None:\n", + " model.load_state_dict(state_dict)\n", + " return mll, model" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "e45f8a78-36e4-4692-9ce3-7e883f7780cb", + "showInput": false, + "collapsed": false + }, + "source": [ + "#### Define a helper function that performs the essential BO step\n", + "The helper function below takes an acquisition function as an argument, optimizes it, and returns the candidate along with the observed function values. \n", + "\n", + "`optimize_acqf_cont_relax_and_get_observation` uses a continuous relaxation of the discrete parameters and rounds the resulting candidate.\n", + "\n", + "`optimize_acqf_pr_and_get_observation` uses a probabilistic reparameterization." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "3fe3ad5b-dabf-4f13-b95f-0766afdf8920", + "collapsed": false, + "requestMsgId": "6d1d90c3-d286-4b04-9f70-cf70d175b629", + "customOutput": null, + "executionStartTime": 1669843261906, + "executionStopTime": 1669843261915 + }, + "source": [ + "from botorch.acquisition.probabilistic_reparameterization import (\n", + " AnalyticProbabilisticReparameterization,\n", + " MCProbabilisticReparameterization,\n", + ")\n", + "from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch\n", + "from botorch.optim import optimize_acqf\n", + "\n", + "NUM_RESTARTS = 20 if not SMOKE_TEST else 2\n", + "RAW_SAMPLES = 1024 if not SMOKE_TEST else 32\n", + "\n", + "\n", + "def optimize_acqf_cont_relax_and_get_observation(acq_func):\n", + " \"\"\"Optimizes the acquisition function, and returns a new candidate and a noisy observation.\"\"\"\n", + " # optimize\n", + " candidates, _ = optimize_acqf(\n", + " acq_function=acq_func,\n", + " bounds=standard_bounds,\n", + " q=1,\n", + " num_restarts=NUM_RESTARTS,\n", + " raw_samples=RAW_SAMPLES, # used for intialization heuristic\n", + " options={\"batch_limit\": 5, \"maxiter\": 200},\n", + " return_best_only=False,\n", + " )\n", + " # round the resulting candidates and take the best across restarts\n", + " candidates = exact_rounding_func(candidates.detach())\n", + " with torch.no_grad():\n", + " af_vals = acq_func(candidates)\n", + " best_idx = af_vals.argmax()\n", + " new_x = candidates[best_idx]\n", + " # observe new values\n", + " exact_obj = eval_problem(new_x)\n", + " return new_x, exact_obj\n", + "\n", + "\n", + "def optimize_acqf_pr_and_get_observation(acq_func, analytic):\n", + " \"\"\"Optimizes the acquisition function, and returns a new candidate and a noisy observation.\"\"\"\n", + " # construct PR\n", + " if analytic:\n", + " pr_acq_func = AnalyticProbabilisticReparameterization(\n", + " acq_function=acq_func,\n", + " one_hot_bounds=rounding_bounds,\n", + " integer_indices=integer_indices,\n", + " batch_limit=128,\n", + " )\n", + " else:\n", + " pr_acq_func = MCProbabilisticReparameterization(\n", + " acq_function=acq_func,\n", + " one_hot_bounds=rounding_bounds,\n", + " integer_indices=integer_indices,\n", + " batch_limit=128,\n", + " mc_samples=4 if SMOKE_TEST else 128,\n", + " )\n", + " candidates, _ = optimize_acqf(\n", + " acq_function=pr_acq_func,\n", + " bounds=standard_bounds,\n", + " q=1,\n", + " num_restarts=NUM_RESTARTS,\n", + " raw_samples=RAW_SAMPLES, # used for intialization heuristic\n", + " options={\n", + " \"batch_limit\": 5,\n", + " \"maxiter\": 200,\n", + " \"rel_tol\": float(\"-inf\"), # run for a full 200 steps\n", + " },\n", + " # use Adam for Monte Carlo PR\n", + " gen_candidates=gen_candidates_torch if not analytic else gen_candidates_scipy,\n", + " )\n", + " # round the resulting candidates and take the best across restarts\n", + " new_x = pr_acq_func.sample_candidates(X=candidates.detach())\n", + " # observe new values\n", + " exact_obj = eval_problem(new_x)\n", + " return new_x, exact_obj\n", + "\n", + "\n", + "def update_random_observations(best_random):\n", + " \"\"\"Simulates a random policy by taking a the current list of best values observed randomly,\n", + " drawing a new random point, observing its value, and updating the list.\n", + " \"\"\"\n", + " rand_x = torch.rand(1, base_function.dim, **tkwargs)\n", + " next_random_best = eval_problem(rand_x).max().item()\n", + " best_random.append(max(best_random[-1], next_random_best))\n", + " return best_random" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "b9344aeb-149e-46a5-9c17-dfbd9ae4727c", + "showInput": false, + "collapsed": false + }, + "source": [ + "### Perform Bayesian Optimization loop\n", + "\n", + "*Note*: Running this may take a little while." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "59cbf551-b4a8-4c90-aa4b-b5a3faf7ac54", + "collapsed": false, + "requestMsgId": "ad801040-9685-4c66-b6c0-4124cddfddf0", + "customOutput": null, + "executionStartTime": 1669843261955, + "executionStopTime": 1669843810811, + "showInput": true + }, + "source": [ + "import time\n", + "import warnings\n", + "\n", + "from botorch import fit_gpytorch_mll\n", + "from botorch.acquisition.analytic import ExpectedImprovement\n", + "from botorch.exceptions import BadInitialCandidatesWarning\n", + "from botorch.utils.transforms import standardize\n", + "\n", + "\n", + "warnings.filterwarnings(\"ignore\", category=BadInitialCandidatesWarning)\n", + "warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n", + "\n", + "torch.manual_seed(0)\n", + "N_TRIALS = 1\n", + "N_BATCH = 15 if not SMOKE_TEST else 2\n", + "\n", + "verbose = True\n", + "\n", + "(\n", + " best_observed_all_pr,\n", + " best_observed_all_pr_analytic,\n", + " best_observed_all_cont_relax,\n", + " best_random_all,\n", + ") = ([], [], [], [])\n", + "\n", + "\n", + "# average over multiple trials\n", + "for trial in range(1, N_TRIALS + 1):\n", + "\n", + " print(f\"\\nTrial {trial:>2} of {N_TRIALS} \", end=\"\")\n", + " (\n", + " best_observed_pr,\n", + " best_observed_pr_analytic,\n", + " best_observed_cont_relax,\n", + " best_random,\n", + " ) = ([], [], [], [])\n", + "\n", + " # call helper functions to generate initial training data and initialize model\n", + " train_x_pr, train_obj_pr = generate_initial_data(n=20)\n", + " best_observed_value_pr = train_obj_pr.max().item()\n", + " stdized_train_obj_pr = standardize(train_obj_pr)\n", + " mll_pr, model_pr = initialize_model(train_x_pr, stdized_train_obj_pr)\n", + "\n", + " train_x_pr_analytic, train_obj_pr_analytic, stdized_train_obj_pr_analytic = (\n", + " train_x_pr,\n", + " train_obj_pr,\n", + " stdized_train_obj_pr,\n", + " )\n", + " best_observed_value_pr_analytic = best_observed_value_pr\n", + " mll_pr_analytic, model_pr_analytic = initialize_model(\n", + " train_x_pr_analytic,\n", + " stdized_train_obj_pr_analytic,\n", + " )\n", + "\n", + " train_x_cont_relax, train_obj_cont_relax, stdized_train_obj_cont_relax = (\n", + " train_x_pr,\n", + " train_obj_pr,\n", + " stdized_train_obj_pr,\n", + " )\n", + " best_observed_value_cont_relax = best_observed_value_pr\n", + " mll_cont_relax, model_cont_relax = initialize_model(\n", + " train_x_cont_relax,\n", + " stdized_train_obj_cont_relax,\n", + " )\n", + "\n", + " best_observed_pr.append(best_observed_value_pr)\n", + " best_observed_pr_analytic.append(best_observed_value_pr_analytic)\n", + " best_observed_cont_relax.append(best_observed_value_cont_relax)\n", + " best_random.append(best_observed_value_pr)\n", + "\n", + " # run N_BATCH rounds of BayesOpt after the initial random batch\n", + " for iteration in range(1, N_BATCH + 1):\n", + "\n", + " t0 = time.monotonic()\n", + "\n", + " # fit the models\n", + " fit_gpytorch_mll(mll_pr)\n", + " fit_gpytorch_mll(mll_pr_analytic)\n", + " fit_gpytorch_mll(mll_cont_relax)\n", + "\n", + " # for best_f, we use the best observed values\n", + " ei_pr = ExpectedImprovement(\n", + " model=model_pr,\n", + " best_f=stdized_train_obj_pr.max(),\n", + " )\n", + "\n", + " ei_pr_analytic = ExpectedImprovement(\n", + " model=model_pr_analytic,\n", + " best_f=stdized_train_obj_pr_analytic.max(),\n", + " )\n", + "\n", + " ei_cont_relax = ExpectedImprovement(\n", + " model=model_cont_relax,\n", + " best_f=stdized_train_obj_cont_relax.max(),\n", + " )\n", + "\n", + " # optimize and get new observation\n", + " new_x_pr, new_obj_pr = optimize_acqf_pr_and_get_observation(\n", + " ei_pr, analytic=False\n", + " )\n", + " new_x_pr_analytic, new_obj_pr_analytic = optimize_acqf_pr_and_get_observation(\n", + " ei_pr_analytic, analytic=True\n", + " )\n", + " (\n", + " new_x_cont_relax,\n", + " new_obj_cont_relax,\n", + " ) = optimize_acqf_cont_relax_and_get_observation(ei_cont_relax)\n", + "\n", + " # update training points\n", + " train_x_pr = torch.cat([train_x_pr, new_x_pr])\n", + " train_obj_pr = torch.cat([train_obj_pr, new_obj_pr])\n", + " stdized_train_obj_pr = standardize(train_obj_pr)\n", + "\n", + " train_x_pr_analytic = torch.cat([train_x_pr_analytic, new_x_pr_analytic])\n", + " train_obj_pr_analytic = torch.cat([train_obj_pr_analytic, new_obj_pr_analytic])\n", + " stdized_train_obj_pr_analytic = standardize(train_obj_pr_analytic)\n", + "\n", + " train_x_cont_relax = torch.cat([train_x_cont_relax, new_x_cont_relax])\n", + " train_obj_cont_relax = torch.cat([train_obj_cont_relax, new_obj_cont_relax])\n", + " stdized_train_obj_cont_relax = standardize(train_obj_cont_relax)\n", + "\n", + " # update progress\n", + " best_random = update_random_observations(best_random)\n", + " best_value_pr_analytic = train_obj_pr.max().item()\n", + " best_value_pr = train_obj_pr_analytic.max().item()\n", + " best_value_cont_relax = train_obj_cont_relax.max().item()\n", + " best_observed_pr.append(best_value_pr)\n", + " best_observed_pr_analytic.append(best_value_pr_analytic)\n", + " best_observed_cont_relax.append(best_value_cont_relax)\n", + "\n", + " # reinitialize the models so they are ready for fitting on next iteration\n", + " # use the current state dict to speed up fitting\n", + " mll_pr, model_pr = initialize_model(\n", + " train_x_pr,\n", + " stdized_train_obj_pr,\n", + " )\n", + " mll_pr_analytic, model_pr_analytic = initialize_model(\n", + " train_x_pr_analytic,\n", + " stdized_train_obj_pr_analytic,\n", + " )\n", + " mll_cont_relax, model_cont_relax = initialize_model(\n", + " train_x_cont_relax,\n", + " stdized_train_obj_cont_relax,\n", + " )\n", + "\n", + " t1 = time.monotonic()\n", + "\n", + " if verbose:\n", + " print(\n", + " f\"\\nBatch {iteration:>2}: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = \"\n", + " f\"({max(best_random):>4.2f}, {best_value_cont_relax:>4.2f}, {best_value_pr:>4.2f}, {best_value_pr_analytic:>4.2f}), \"\n", + " f\"time = {t1-t0:>4.2f}.\",\n", + " end=\"\",\n", + " )\n", + " else:\n", + " print(\".\", end=\"\")\n", + "\n", + " best_observed_all_pr.append(best_observed_pr)\n", + " best_observed_all_pr_analytic.append(best_observed_pr_analytic)\n", + " best_observed_all_cont_relax.append(best_observed_cont_relax)\n", + " best_random_all.append(best_random)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\nTrial 1 of 1 ", + "\nBatch 1: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.98, -1.90, -1.98), time = 45.77.", + "\nBatch 2: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.98, -1.90, -1.98), time = 64.38.", + "\nBatch 3: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.96, -1.74, -1.98), time = 38.69.", + "\nBatch 4: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.96, -1.20, -1.84), time = 36.71.", + "\nBatch 5: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.96, -1.20, -1.51), time = 35.81.", + "\nBatch 6: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.96, -1.20, -1.51), time = 33.21.", + "\nBatch 7: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.96, -1.20, -1.51), time = 34.27.", + "\nBatch 8: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.83, -1.20, -1.51), time = 36.21.", + "\nBatch 9: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -1.20, -1.51), time = 37.24.", + "\nBatch 10: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -1.20, -1.51), time = 32.22.", + "\nBatch 11: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -0.48, -1.09), time = 23.55.", + "\nBatch 12: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -0.48, -1.09), time = 34.03.", + "\nBatch 13: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -0.48, -1.09), time = 33.57.", + "\nBatch 14: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -0.48, -1.09), time = 41.31.", + "\nBatch 15: best_value (random, Cont. Relax., PR (MC), PR (Analytic)) = (-1.98, -1.56, -0.48, -0.79), time = 21.94." + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "587be90e-69f5-4b33-ad40-1aafe38d305c", + "showInput": false, + "collapsed": false + }, + "source": [ + "#### Plot the results\n", + "The plot below shows the best objective value observed at each step of the optimization for each of the algorithms." + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "cd365490-cf84-4456-a033-3b58879a0293", + "collapsed": false, + "requestMsgId": "03a1289e-6f40-4066-aced-875fb75dac36", + "customOutput": null, + "executionStartTime": 1669843815143, + "executionStopTime": 1669843815915 + }, + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "\n", + "%matplotlib inline\n", + "\n", + "\n", + "def ci(y):\n", + " return 1.96 * y.std(axis=0) / np.sqrt(N_TRIALS)\n", + "\n", + "\n", + "iters = np.arange(N_BATCH + 1)\n", + "y_cont_relax = np.asarray(best_observed_all_cont_relax)\n", + "y_pr = np.asarray(best_observed_all_pr)\n", + "y_pr_analytic = np.asarray(best_observed_all_pr_analytic)\n", + "y_rnd = np.asarray(best_random_all)\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + "ax.errorbar(iters, y_rnd.mean(axis=0), yerr=ci(y_rnd), label=\"Random\", linewidth=1.5)\n", + "ax.errorbar(\n", + " iters,\n", + " y_cont_relax.mean(axis=0),\n", + " yerr=ci(y_cont_relax),\n", + " label=\"Cont. Relax.\",\n", + " linewidth=1.5,\n", + ")\n", + "ax.errorbar(iters, y_pr.mean(axis=0), yerr=ci(y_pr), label=\"PR (MC)\", linewidth=1.5)\n", + "ax.errorbar(iters, y_pr_analytic.mean(axis=0), yerr=ci(y_pr_analytic), label=\"PR (Analytic)\", linewidth=1.5)\n", + "ax.set(\n", + " xlabel=\"number of observations (beyond initial points)\",\n", + " ylabel=\"best objective value\",\n", + ")\n", + "ax.legend(loc=\"lower right\")" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "" + }, + "metadata": { + "bento_obj_id": "139712286187136" + }, + "execution_count": 9 + }, + { + "output_type": "display_data", + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAF6CAYAAACEHlvDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzdeXxcVf3/8ddM9jRJm7XpTmlpy1p2LLvIsAwIDoqsYkU2BWUT8CcoigIiX3dBEHFDENQvAwIDOGUtBUoLtIWWLkn3NmmTTPZkMtv9/ZEbvqEk6aSZmXtn8n4+HnkkuXPn3s9p0plPzvmccxyGYSAiIiKjl9PqAERERMRaSgZERERGOSUDIiIio5ySARERkVFOyYCIiMgop2RARERklBvNyYCR6I9AQ33Cr2mHD7UrvT7UrvT6ULvS6yMD2jWg0ZwMJFwsGrM6hKRQu9KL2pVe1K70kqntUjIgIiIyyikZEBERGeWUDIiIiIxySgZERERGOSUDIiIio5ySARERkVFOyYCIiMgop2RARERklFMyICIiMsopGRARERnllAyIiIiMctlWByAiIskXCcbY8UEPhjHoXjUJ1docI7i5KyX3SqVUt6tyvzzyirKSfh8lAyIio8AHT7Sy7JHWFN91Z4rvlyqpa9cZv6mmaj8lAyIikgANa0KMnZrDcTdVpOR+LU0NjCuvTMm9UinV7Ro3LScl91EyICIyCgRqQkw6PJ/KffNScj9HvZOK6tTcK5UytV0qIBQRyXBdgQjdgShlM3KtDkVsSsmAiEiGC9SEACifmXl/0UpiKBkQEclwfcmAegZkMEoGREQyXFNNiOIJ2eQW6SVfBqbfDBGRDBeoDalXQIakZEBEJIOFu2K0bYtQNlPJgAxOyYCISAYL1PYVDyoZkMEpGRARyWBNfcWDSgZkCEoGREQyWKA2RP44J4XlyV/SVtKXkgERkQzWVNNbPOhwOKwORWxMyYCISIaKhg1aNoY0RCC7pWRARCRDtW4OE4uoeFB2T8mAiEiGaqrpAS1DLHFQMiAikqGaakJk5zsonqgNamVoSgZERDJUoDZE6d65OLNUPChDUzIgIpKBjJhBoCakegGJi5IBEZEM1F4fIdxlKBmQuNhqIMnl9hQA9wLnAiXASuBmv8/78m6edylwCzAN2A781u/z/jJ1kYuI2Iu2LZbhsFvPwP3AycBngQrgH8BzLrdnn8Ge4HJ7zgN+AFwCjAW+BVzlcnuOSG3oIiL20VQTwuGEcdNzrA5F0oBtegZcbk8ZcDHwJb/Pu8o8/HOX23MRcBVw4yBPvR24ye/zLja/f878EBEZtQI1IcZNyyE7125/84kd2SYZAA4141myy/F3gM8M9ASX2zMB2BfIcrk9S4HZwHrgTr/P+8/UhC0iYj9NtSEmHpJvdRiSJlKWDLjcnmygaIhTqszPTbscbwSqB3nOVPPzVcAFZr3AZcATLrdnh9/nfW2wmwUa6olFY8NrxG5EImEa67cn9Jp2oHalF7UrvSSjXT2tBt1NUfLGBy37N9PPy54qqicOeDyVPQMnAv4hHr9okOMOwBjksb74f+T3edeZX//aHFr4GjBoMlBWOVh+seca67cP+g+dztSu9KJ2pZdktGvblm5gB1MPqaCiuiCh146Xfl7pJWXJgN/nXWC+sQ/I5facZH5ZCWzt91AVUD/I0xrNz827HF8PTBhZxCIi6alvGWJtUCTxslNlybtACJi3y/GjgUWDPKfGTAiO2uX4TGBDkuIUEbG1ppoQRdXZ5BVlWR2KpAnbFBD6fd5Wl9vzMHC3y+1ZBWwyZxBMM6ccQm8Pwt3AVL/Pe5Hf54263J5fALe73J73gRXA5cAh5mcRkVEnUKtti2V4bJMMmK4H7gFeAYqBZcApfp93U79zJpgJQp+fmj0c/wbKgdXA5/0+7/sWxC8iYqlwd4y2bRFmnDxUvbbIJ9kqGfD7vD3AdebHYOfM3+V7A7jT/BARGdUCtSEwtPKgDI+dagZERGSEArW9yxBrTwIZDiUDIiIZpKkmRN5YJ4UVKh6U+CkZEBHJIIGaEOUzcnE4Bp3JLfIpSgZERDJELGLQvFEzCWT4lAyIiGSIls1hYmHVC8jwKRkQEckQTTW9xYPqGZDhUjIgIpIhAjU9ZOc7KJmUY3UokmaUDIiIZIhATYjS6bk4s1Q8KMOjZEBEJAMYhkGgNqR6AdkjSgZERDJAR12EUKehlQdljygZEBHJAE21Kh6UPadkQEQkAwRqQjicUDpdxYMyfEoGREQyQFNNiLFTc8jO08u6DJ9+a0REMkCgRsWDsueUDIiIpLnu5ihdTVHVC8geUzIgIpLmtG2xjJSSARGRNPfxMsR7KxmQPaNkQEQkzQVqQxSNzyKvJMvqUCRNKRkQEUlzgZoeymbmWR2GpDElAyIiaSzcHaN1a0QrD8qIKBkQEUljgfUhMFQ8KCOjZEBEJI0FarQMsYyckgERkTQWqAmRV+JkTKWKB2XPKRkQEUljTbUhymbk4nA4rA5F0piSARGRNBWLGLRs0DLEMnJKBkRE0lTL5jDRsOoFZOSUDIiIpKm+4kH1DMhIKRkQEUlTTTUhsvIclEzOsToUSXPZVgfQn8vtKQDuBc4FSoCVwM1+n/flIZ4zB/gZcDSQA6wG7vD7vM+lNnoRkdQK1IYom56DM0vFgzIydusZuB84GfgsUAH8A3jO5fbsM9DJLrfHAbwItAMzzOc8CjxlJgkiIhnJMIzeZEBDBJIAtkkGXG5PGXAxcIvf513l93k7/T7vz4GPgKsGeVoVMBV4zO/ztvp93jDwsNnjMTfFTRARSZmO+gihjpj2JMhQ0c526v5+P+FAY0ruZ5tkADjUfBNfssvxd4DPDPQEv8+7A3gNuNTl9lS43J484EqgCXg1NWGLiKReoFbFg5ks4H+awItPEu1oS8n9UlYz4HJ7soGiIU6pMj837XK8Eage4nlfBp4HGgDDPP9LZqIwqEBDPbFoLP4GxCESCdNYvz2h17QDtSu9qF3pZU/btWVZBBwQK2yksd5+NQP6ee05I9RDwwv/S+6cuXTk5tORwPtVVE8c8HgqCwhPBPxDPH7RIMcd5pv8p7jcnlyzZmA1cDrQAXwVeMbl9szz+7wfDnazssqh8os901i/fdB/6HSmdqUXtSu97Gm7gvU7GDc1wvipk5IS10jp57XnmvxPYXS2M/FL8xmTon/DlCUDfp93gfnGPiCX23OS+WUlsLXfQ1VA/SBPOxk4GDjd7/PuNI/93uX2fAO4FLghcS0QEbGPptoQ1QflWx2GJJgRidDk+xeFs/ZnzOwDU3ZfO9UMvAuEgHm7HD8aWLSb5+6aZGQP1psgIpLugq1RuhqiqhfIQK1vv0K4cQcVZ16Q0vvaZp0Bv8/b6nJ7Hgbudrk9q4BNwI3ANHPKIfT2INwNTPX7vBcBb5q9Bj91uT3XAZ3mjITZZiGhiEjGaerbtniGkoFMYsRiND77BHmT96Jo7pEpvbedegYArgeeBV4xCwJPA07x+7yb+p0zwUwQ8Pu8LcCpQDmw1nzON4Bz/T7vQuuaISKSPFqGODN1LFtMz7aNVJx5Pg5nat+ebdMzQO+bew9wnfkx2Dnzd/l+BXBmSgIUEbGBQE2IMVVZ5JVkWR2KJIhhGDQ88w9yKqoZ+5nPpvz+dusZEBGR3WiqDWmIIMN0rfmA7ppVlLvPxZGV+iRPyYCISBoJd8do3RLWEEGGaXzmcbKKx1F6/KmW3F/JgIhIGmneEAIDLUOcQYKbaulY8Q7lp3pw5lkzXVTJgIhIGlHxYOZpfO5xnPkFlJ18lmUxKBkQEUkjTTUhcoudjKlS8WAmCO3cTuvbr1F60ufJGlNsWRxKBkRE0kigJkT5zFwcDvvtRyDD1/jcP3FkZVF+2jmWxqFkQEQkTcSiBs0bQpRpiCAjhFsCtCx8kXHHusgprbA0FiUDIiJponVzmGgYyjWtMCMEXnwSIxKl/IwvWx2KkgERkXTx8TLE6hlIe9GuDgIv/YeSI48jr3qy1eEoGRARSReB2hBZuQ7GTsmxOhQZocBLzxDr7qLizPOtDgWUDIiIpI+mmhCl03NwZql4MJ3FQj0EXnySMQceRsFe+1gdDigZEBFJD4ZhEKhR8WAmaFn4XyKtzVSmeJvioSgZEBFJAx07IoQ6YlpsKM0Z0SiNvn9SsPccCveda3U4H1MyICKSBgK1fSsPahnidNb2zmuEd9ZR8fkLbLVWRNzJgMvtcbrcnhNdbs/8fscKkxaZiIh8LFATwuGE0ukqHkxXvdsUP07uxKkUHzrP6nA+Ia5kwOX2TAM+BBYAD5rH9gI2uNyeA5IepYjIKNdUE6Jkcg7Z+erQTVcdK5bQs2U9FWech8Npr59jvNH8AlgMVAAxAL/PuxH4g/mYiIgkUd8yxJK+Gp95nOyySsYefZLVoXxKvMnAPOA6v8/bAhj9jt8NHJmk2EREBAi2RulsiFKmlQfTVtfalXStWUGF+1yc2fYb6ok3GSgGugc4ngto6ywRkSTStsXpr/HZx8kqKqb0hNOtDmVA8SYDi4Eb+x9wuT1F/YYPREQkSZpqtQxxOgtu2UD7+29R5vLgzC+wOpwBZcd53neAF1xuz5VAnsvtWQbMADqB05Ico4jIqBaoCVFYmUX+WHXEpqPG557AkZtPmetsq0MZVFw9A36fdxkwE7gH+A3wkpkg7GM+JiIiSdKk4sG0FWrcQetbL1P6WTfZxWOtDmdQ8fYM4Pd5O4DfJzccERHpLxKM0bY1zPQTtKxLOmry/ROHw0mF+1yrQxlSXMmAy+15eahr+H3e4xMXkoiI9GneEMaIqV4gHUVam2l+9XnGHvM5csoqrQ5nSPH2DGzb5XsnMBU4EPhjEuISERGgqaYHQNMK01CT/ymMSJgK95etDmW34koG/D7vVwY67nJ73IAr4VGJiAiY9QK5xU6Kxsc9qis2EO3uIuB/muLDjiFv0jSrw9mtEa2H6Pd5fUoGRESSJ1AbomxGrq02tZHda37lOWJdHVSceb7VocRlRMmAy+2ZCpQnLhwREekTixo0rw9TriGCtBILh2h64d+M2e8QCmfMsTqcuMRbQPj6AIdzgf2AFxIfloiItG4JEw0ZKh5MM62LFhBpbmLS5TdbHUrc4h2EWr/LngSYyxM/CjycyIBcbs904M/ACcB0c0Okoc4vAO4FzgVKgJXAzX6fd6gZECIitqdliNOPEYvS+NwT5O+1D2MOONTqcOIWbwHh/OSHAi63xwM8MMzehvvNjZQ+C2wCrgKec7k9B/l93nVJDFdEJKmaakNk5cDYKfbb2EYG1rZ0EaH6bUy+5vtpVecxaDLgcnsujfcifp/3TwmKpww4HpgCXLK7k11uTxlwMfAlv8+7yjz8c5fbc5GZFNy4m0uIiNhWoCZE6fRcnNnp86YymhmGQeMz/yC3ehIlRxxrdTjDMlTPQLzrBxhAQpIBv8/7ML1v8lPifMqhZhuW7HL8HeAziYhJRMQKhmHQVBNir+O08mC66PzwPYIb1zHx6zfgcKbXPhKDJgN+nzeumQYutyc/zvOygaIh7tcSz3V2UWV+btrleCNQPdQTAw31xKKxPbjl4CKRMI312xN6TTtQu9KL2pVeBmtXV6NBqD1GXlV3WrZ7tP28AJqf/CvOknFEZu5v27ZXVE8c8PiIVrEwi/dqgYGv/kknAv6hruX3eYMjiacfxwAFj59QVjlkrrBHGuu3D/oPnc7UrvSidqWXwdq1qbYL2MnUQyupqI7rby5bGW0/r67a1eyoWcX486+gYrL9FxnaVbxTCyvMiv2jgP6/lWVAazzX8Pu8C8w36USqNz9XAlv7Ha/q95iISNoJ1PSAA0qnayZBOmh89nGchUWUnnSm1aHskXgXHfqtOT7/jFnc9w9gLbDK/IvfKu8CIXM2QX9HA4ssiklEZMSaakKMnZJDTsGI1oaTFOjZvpn2dxdR5jqbrIL0rPGI97fsJMDl93lvAcJ+n/dWv897GvAycE6SY/wEl9tzt8vteZTe3oZWc52Du11uz/4ut6fI5fbcDkwzpxyKiKSlvmWIxf4an3sCR04u5ad4rA5lj8WbDBT4fd6d5tdRsxgQ4C7g24kKxuX2rHG5PUHAZx5a43J7gi6356F+p00w3+z7XA88C7wCNACnAaf4fd5NiYpLRCSVgq1ROndGtdhQGgg37aR10UuUnnAa2SXjrA5nj8VbQLjG5fZcD/wa2Gz2BvzTHKsvTVQwfp93dhznzN/l+x7gOvNDRCTtBWp7Vx7UMsT21/TC/2IYMcpPP9fqUEYk3p6BHwE/NacG/hl41OX2LAHeH2qGgIiIDF9fMqANiuwt0t5K8yvPMXbeSeQmYYZaKsWVDPh93meBqX6ft83v8/4PMB942xwm2O1KgSIiEr+mmhCFFVnkj0uvhWtGm4D/aWI9QSrOOM/qUEYs3qmFPwL+2ve93+d91NykSEREEixQE1K9gM3Fgt0E/F6KD5lH/pTpVoczYvHWDFwC3OZye94E/gL8y+/ztiU5NhGRUSfSE6N1S5hpWobY1ppfe55oRzsVZ55vdSgJEe8wwXTgWHNe/x1Ancvteczl9pzqcnu0g4aISII0bwhjxNC0QhuLRcI0Pf8vCmcfROGs/a0OJyHiXs3C7/O+5fd5rwMmA583Vx78E7AluSGKiIweTTVm8aCGCWyr9c2XCTc1UPH5zOgVYDjJQB+/z2sA7UAbEABKkhOaiMjoE6jpIXeMg6LqEW0dI0lixGI0PvcEeVP2puigI6wOJ2Hi/m1zuT2HAV8GzjV7BxYAdwPe5IYoIjJ6BGp6Vx50ODQCa0c9q94ntH0zk7/5vYz6GcU7m6AW2MtcV+A3wGP9ViQUEZEEiEUNAhvCzD6j2OpQZACGYdD58jPkVE2g5MgTrA4noeLtGXgceMTv865OcjwiIqNW29Yw0R4jKfUCre+8Tv3f7wdjyN3dEyYWi9LkzLB1EowYkdZmJsy/FkdWZrUtrmTA7/PemvxQRERGt77iwWQsQxxY8DRgUHzwUQm/9kCC3V3kp+kOfkPpMWDccadaHUbCqUJFRMQmAjUhsnJg3NSchF430tZC1+oPqDz7Qqq+OD+OZ4xcY/12KqonpuReqdRYvx1nbubN9NBG2SIiNtFUE2Lc9Fyc2YktTGt7dxEYMUoOPy6h15XMoWRARMQGDMPoXYY4CYsNtS99g9yqieRN3Tvh15bMEHcy4HJ7nC6350SX2zO/37HMGxASEbFAZ0OUnvZYwusFop0ddK58n+LDj82oqXCSWHElAy63Zxrwobm2wIPmsb2ADS6354CkRykikuECSSoebF/2NkY0QskRGiKQwcXbM/ALYDFQAcTonWGwEfiD+ZiIiIxAU00IHFC2d2KTgbYlC8kuraBg79kJva5klniTgXnAdX6ftwXoP0n1buDIJMUmIjJqBGpDlEzKJqcgcaVcsWA3HSuWUHL4sTicKhGTwcX721EMdA9wPBfIrJUXREQsEKjpSfhiQ+0rlmCEQ5QccWxCryuZJ95kYDFwY/8DLrenqN/wgYiI7KGetigdO6KUzcxL6HXbliwkq3gchbMPTOh1JfPEu+jQd4AXXG7PlUCey+1ZBswAOoHTkhyjiEhGa6pN/LbFsVCIjmVvU/KZE3Fk2rLAknBx9Qz4fd5lwEzgHnOjopfMBGEf8zEREdlDATMZKEvgGgOdK98jFuzWLAKJS7y7Fv4I+Kvf5/198kMSERldAjUhCsuzKChN3F/wbUsX4iwcw5j9DknYNSVzxTtMcAlwm8vteRP4C/Avv8/bluTYRERGhaaaUELXFzAiEdrfe5PiQ+bhzE7sPgeSmeIdJpgOHAu8C9wB1Lncnsdcbs+pLrdHS1qJiOyhaMigdXM4sUMEa1YQ7WjXXgQSt7gnnvp93rf8Pu91wGTg80Ar8CdgS3JDFBHJXO1bDIxYYosH25YsxJGbT9GBhyXsmpLZhr0Khd/nNYB2oA0IACXJCU1EJPO1bupdxy1RyYARi9G+dBHFc4/EmZefkGtK5ou3ZgCX23MY8GXgXLN3YIG5AqE3uSGKiGSu1k0GOYUOiqrjfjkeUnfNKiKtAS00JMMS72yCWmAv4H1zauFjfp93Z/LDExHJbG2bDMpm5uJwJqb8qm3JQhzZORQdfFRCriejQ7yp6OPAI36fd3WS48Hl9kwH/gycAEw3N0Qa6vzx5voHpwEFwErge36f99VkxyoiMhKxqEHbFoPZZyRoiMAwaFv6BmMOOIysgjEJuaaMDoMmAy635wS/z/ua+e1LwESX2zNxoHP9Pu/LiQjG5fZ4gAeAF4bxtKeBZuBgoAW4HXjW5fbM8vu82xMRl4hIMrRtDRPtIWHLEAc3riPcuINKz1cScj0ZPYbqGXjB/Esbsz7AAAbqxzISuFlRGXA8MMVc22BILrenxOwJuNfv89abx+4Bvgt8BngyQXFJholFDHraDYIFUatDSTi1K33s+LAHgPIETStsW/oGOJ0UHzIvIdeT0WOoZKD/5tfTUxALfp/3YXrf0KfEeX4b8PVdDu9tfh6yVyDQUE8sGtvTUAcUiYRprM+8zohMbNdrt4Zp22Rk8MxYtStdOLMhktdAY/3IagYMw6D5rVfI3XsOLZ2d0NmZsBj3RCa+bpAB7aqoHrCDf/BkwO/zbu737Q1+n/faXc9xuT1jgd/G+Vd8NlA0xP1adneNOO5RYtYbPOf3ed8e6tyyyuqR3u5TGuu3D/oPnc4yrV0dOyK0bdrKpKOdTD50nNXhJFxnWytjSsZaHUbCZWq7HEVtVE2eNOLrBLduZGdDHVXucymzwf/XTHvd6JOp7RqygNDl9owzu+6vcLk9vxpgmGA/4IvxJAPAiYB/iHsV+H3eYNyRf/r504BngZ3ABXt6Hcl8dct7f81mnulk7yMzb5mMxvoOKqrVrnTRWN+RkOu0LX0DHA6KDz8mIdeT0WV3swkuAn5lLk5UM8DjDrO4cLf8Pu+CQWoORszl9hxhJgJe4Ft+nzecjPtIZqhf1k1eiZPiyVpJWzJH+5KFFO6zHznjyq0ORdLQkMmA3+e9z+X2PArsAE4Z4JROwNItjF1uzwFmseMdfp/311bGIumhfnmQ6oPycTgjVocikhChHdsJbq5l/IVXWR2KpKndLkdsjuUfbE4zXOz3eV8zv14OvOv3eVP6iupye+42ExRcbk8W8FfgfiUCEo/2ujAdO6JUH6xlWiVztC1dCEDJ4Vp1UPZMvIsOhVxuzxrgVuDf5rGvA1e53J7T/T7vQEMIw2beY1q/JGWNy+0xzAWPLjePTTDPAZgHHAoc4HJ7btrlcv2fIwJA3bLeeoEJB+cTJTFjtSJWa1v6BvnTZ5GbhMJoGR3iTQZ+A7wNvNbv2F+AmeZsgtMTEYzf550dxznz+339RrLqECQz1S8Pkj/OybhpOTTtsDoakZELBxrorvmIqnMvtToUSWPx7lp4NHCl3+dt6Dvg93mbgBvMxX1EbM8wDOr66gUcyiElM7QtfQOAkiOOszoUSWPxJgNRYKAJvlVAYlfuEUmS9u0RuhqiTFC9gGSQtqVvkDdpGnkT4lqrTWRA8Q4TPA086XJ7fgpsMLvm5wD/z3xMxPb66gVUPCiZItLWQtfqD6g8+0KrQ5E0F28ycK1ZN/Ck2ZvgMHsL/g5cl+QYRRKiflmQgrIsxk7JsToUkYRof+9NMGIUaxaBjFBcyYDf5+0ALnW5Pdf2W/t/vd/nbU9ueCKJYRgGdSuCVM9VvYBkjrYlC8mpmkD+1BlWhyJpLt6aAVxujxM4DDjE7/Mu9/u87S63pzC54YkkRtvWCN1NUSbM1RCBZIZoZwedK9+n5PDjlODKiMWVDJjr/n9obmX8oHlsL2CDuQKgiK3VLesG1QtIBmlf9jZGNELJERoikJGLt2fgF8BioKJv9oDf590I/MF8TMTW6pcHKSzPomRSvGUyIvbWtmQh2aUVFOw9x+pQJAPEmwzMA64zlyY2+h2/GzgySbGJJIRhGL37ERysegHJDLFgNx0rllBy+DE4nHGP9ooMKt7fomKge4DjuUBWgmMSSajWzWG6m2NaX0AyRvuKJRjhkBYakoSJNxlYDNzY/4DL7SnqN3wgYlsfry+g4kHJEO1LF5JVPJbCWQdaHYpkiHgHUL8DvOBye64E8lxuzzJghrmF8WlJjlFkROqXBxlTmUXxBNULSPqLhUO0v7+Yks+cgCNLHbOSGHH1DPh93mXmpkT3mIsPvWQmCPuYj4nYkhEz9yNQvYBkiM6V7xELdlFyuIYIJHHi/lPJXHjo98kNRySxmjeG6WlVvYBkjrYlC3EWjmHM/odYHYpkkEGTAZfb4/f7vC7z69d3cx0DaAR+7fd5d3euSMrUL1e9gGQOIxKh/b03KT5kHs5sLastiTNUz8CGfl/XxnGtvYFHAW2dJbZRtzxIUXU2xdV64ZT017lmBdGOdkq0F4Ek2KDJgN/nvaLf11/b3YVcbk8W0JbI4ERGwogZ7FgeZOrRWjVbMkPbkoU4cvMpOvBwq0ORDBN3zYDL7TkVOAeYCgSBzcCjfp/3HXoThigwJqnRigxDYH2InvaYhggkIxixGO1LF1E89wicefqdlsSKd2+Ca4DnzdUGu8wagROBt1xuz+XJD1Nk+D6uF1DxoGSA7ppVRFoDWmhIkiLenoFrgQv9Pu/j/Q+63J6LgduBh5ITnsieq1sWpHhiNkVVWl9A0l/bkoU4snMoOvgoq0ORDBTvCoQTgH8PcPxxYGKCYxIZsVjUYMeKoLYsloxgGAZtS99gzAGHklWg0VhJvHiTgbeBuQMcPwBYmuCYREYssD5EqNPQEIFkhODGdYQbd2iIQJJmqHUGju/37SPAX11uz5+B1UAU2Bf4GnBXakIViV+9uR+BegYkE7QtfQOcTooPmWd1KJKhhhpMfXWAY/cOcOxRc7hAxDbqlgUpmZxNYYXqBSS9GYZB25KFjNl3LtnFY60ORzLUUOsMxDvTQK+2YiuxqMGOD4JM/6zGViX99WzfTKhuC+WnfMHqUCSDDeuN3OX2HGbuVhgD1vp93hV+nzeSvPBEhq9pXYhwl6H9CCQjtC1ZCA4HxVp1UJIormTA5ch2ZUsAACAASURBVPaMB/xmwWAfw+X2LAbO8Pu8zckLUWR4tB+BZJL2JQspmLkfOePKrQ5FMli8swl+AdQDRwHjgDLgOKDH3NZYxDbqlgUZOzWHwjKNYEl6C+3YTnBzrWYRSNLF+2p5PHCY3+fd2e/Ymy635yvAokQG5HJ7pgN/Bk4Apvt93o3DeO4xwOvAj/0+7w8TGZekh1jEYMeHQWaeXGR1KCIj1rZ0IQAlhx1jdSiS4eLtGSgaZBOiRqAiUcG43B6PuabBpj14boGZRHQkKh5JP41re4h0a30ByQxtS98gf699yK2aYHUokuHiTQZWAlcNcPybwEcJjKfM7IV4ZA+ee5e5BsL7CYxH0kzf+gLVBykZkPQWDjTQXfORhggkJeIdJvgesMDl9lxhvvk7gP2AvYCzExWM3+d9mN6/8qcM53kut+dY4CvAgcA/EhWPpJ+65UHG7ZVDQWmW1aGIjEjbu70jsCWaRSApEFcy4Pd5X3e5PfsBVwIzzcNPAQ/EO6Zvrkcw6ECu3+dtiTfoXa5baA4PXOf3eetcbk9czws01BOLxvbkloOKRMI01m9P6DXtIF3a1VsvEGbK8c644k2Xdg2X2pVeBmtXYNFLZI2fRLszm/Y0bPdo+3mli4rqgbcTirvc2u/z1gA3jSCGE83piQNyuT0Ffp83uAfXvQtY5fd5/z6cJ5VVVu/BrYbWWL990H/odJYu7drxYZBoTz3Tjy6nonr3Cw6lS7uGS+1KLwO1K9LWwo71a6g464K0bfNo+nllgpTNvfL7vAvM4YWEMYcHLgQOSuR1JT3VqV5AMkT7e2+CEVO9gKRMuk/E/jpQAnzYb3hgLHCky+05y+/zHmpteJJK9cuDlO6dQ/5Y1QtIemtbspCcymryp86wOhQZJeKdTWAbLrfnbpfb86j57Q1mDcPB/T6WAg8AbotDlRSKhgx2ruzREsSS9qKdHXSufJ+SI47D4UhoZ6rIoOJdjvhwv8+7dIDjBcDn/D7vs4kIxuX2rAGm9UtS1rjcHgN4xO/zXm4em2Ceg7kMcvMu1+gB2vw+b30iYpL00LC6h2jIoHpugdWhiIxI+7K3MaIRDRFISsU7TPA6UDjA8QLgb+b6ACPm93lnx3HO/N08fmIiYpH0UrcsCA6oPijP6lBERqRt6Rtkl5ZTsPccq0ORUWTIZMDl9nwNuBTIdbk9rw9wyiSgO3nhicSnfnmQshm55BWrXkDSVyzYTceKJZSecBoOZ9qN4koa213PwH/NtQHmAbUDPL4CGNaUPpFEi4RiNKwKMufsEqtDERmR9hVLMEI9lByuIQJJrSGTAb/Puw34rcvtmeD3eb+XurBE4tewqodoGBUPStprX7qQrOKxFM4+0OpQZJSJt2bgBy6353K/z/sQvcMHpwFXAGuB2/0+b09ywxQZXN2yIA4njD9AyYCkr1g4RPv7iyk56gQcWRruktSKd1DqHuB6ehOBqcD/mlX8xwE/TW6IIkOrXx6kfGYuuUUaY5X01bnyPWLBLs0iEEvE++r5ZeAs8+uLgMV+n/frwJeALyYxPpEhRYIxGj7q0ZbFkvbalizEWVDImP0OtjoUGYXiTQZKzb0JAFzAf+itKagDypMXnsjQdq7sIRaB6rlKBiR9GdEo7e+9SfEh83Dm5FodjoxC8dYMNLrcnr2AHuAY4Gr+b6vh1uSGKDK4uuW99QLVByoZkPTVuXo50Y52DRGIZeJNBh4B3gRiwOt+n/cjl9tTBPwV8CU5RpFB1S8PUjErl5xC1QtI+mpbshBHbj5FBx5udSgySsX1Cur3eb9v7gNwb78agZC59sD1yQ1RZGDh7hgNq3uoPlhLEEv6MmIx2pcuonjuETjz1MMl1oh710K/z/s4vUMD2eb3IeDy3T5RJEl2ftiDEYUJqheQNBbeVEOkNUDx4cdaHYqMYvFuVJQN3AxcCVQChS63Z4zZU3CdmRiIpFTd8iCOLKg6QPsRSPrq+XApjuwcig/5jNWhyCgW70Drd81Fhu4F+vbULAQO1zoDYpX6Zd1Uzskjp0D1ApKeDMMg+MFSxux/CFkFY6wOR0axeF9FLwLO8vu8vwMMeocJGoDzzTUIRFIq3BWjcW1IUwol4QzDwIhFU/IR3LCWWHOjZhGI5eKtGZgKfDDA8c1aZ0CssOPDIEYsPeoFjEiEaGd7yu4X62gj0pp5RZXR1gAhB8RCPRjhHmKhEEY41Pt9KEQs3Pc5hBHqMc8Lf/L8UI/5nJB5bIBzwiEwjNQ1zOmk+NCjU3c/kQHEmwxsB44A3tnl+FnAliTEJTKkumVBnNlQtb+96wUMw2D9Hd8muGFtSu/bkNK7pU7jcE52OHHm5uLIycWRm2d+nYczJxdHbi5ZRSVk5+bhzMnpPZ6bh6Pv/KxscDjiuMnIhQqLyS4em5J7iQwm3mTgV8B/XG7Pg0CWy+25FjjUHCK4KckxinxK/bIglfvmkZ1v73qBrtXLCW5YS+lJZ5I/ZXpK7tnR1kpRSea9uXS0t1NSWfWJN25n3xu9+dmRk9v7pp+bhyMrG0eK3tBHorF+u9UhiMSXDPh93vtcbk+juaZAF/ADc8fC+X6f94nkhynyf0IdMZpqQhx0of3f8AL+p8kqKqb6om/gzE1NL0asfjtl1RNTcq9UitVvZ1wGtkvEDoazzsATgN74xXIf1wvYfHOicNNO2t5dRPnp56YsERAR2RNxJwMut+dU4ByzmDAIbAIe8/u8u9YRiCRV3bIgWTlQuZ+932ADLz0DBpR97vNWhyIiMqS4Blxdbs81wPPAkeYwgQF8FnjL5fZoFUJJqbq+eoFc+9YLxEIhml/1UXzIZ8itrLY6HBGRIcXbM3AtcGHfksR9XG7PxcDtwEPJCU/kk3raowRqQxz8lXFWhzKktsWvEm1vpeyUL1gdiojIbsX7p9UE4N8DHH8cUEWPpEz9ih4w7F0vYBgGTf6nyJs0jTH7HWJ1OCIiuxVvz8DbwFzg3V2OHwAsTUJcIgOqX95NVq6Dyjn2rRforvmI4Ia1TPjqt9Niapt8UswwaIpEUna/QDSKEQ6n7H6ponYlxrisLHKcyR8SHTQZcLk9x/f79hHgry6358/AaiAK7At8Dbgr6VGKmOqWBanaP4+sXPu+yQb8T+EsKGTssS6rQ5E9cNfmzXibmlJ70x0ZukyU2jVif5k9mwPHJH/fiqF6Bl4d4Ni9Axx71BwuEEmqYGuU5vVhDplv3w1dwi0B2t55ndKTzyIrP/OWBM50W4JB/tPUxEnjxvGZ4uKU3DNjF4lSuxJiYm5uSu4zaDLg93njnWkQ9/REkZGoXxEE7F0v0PzysxjRCGUnn2V1KLIHHqqvJ9vh4JYpU6jIyUnJPRujYSoqK1Nyr1RSu9LLiAci/D5v6gbXZFSrXxYkO99BxWx71gvEImGaX36WooOOIK96stXhyDBtDgZ5PhDgS5WVKUsEROzCvhO1RXZRt9ysF8ixZ71A+5I3iLQGKHNpOmE6+mN9PTkOB18dP97qUERSznZd/C63ZzrwZ+AEYLrf590Yx3MuBW4Bppk7LP7W7/P+MjURSyoEW6K0bAwz4yT71gs0+Z8id/wkig46wupQZJg2mb0CF1ZVUa5eARmFbNUz4HJ7POY0xk3DeM555sZJlwBjgW8BV7ncHr0iZ5C65b31AtU2rRfo3riO7nUrKTv5LBwpmAYkiaVeARnt7NYzUAYcD0wx39zjcTtwk9/nXWx+/5z5IRnk43qBWfasFwj4n8KRm8+44061OhQZpo3BIC8EAlxUVUWZegVklLJVMuD3eR+m96/9KfGc73J7JpjrHWS53J6lwGxgPXCn3+f9Z9IDlpSpWx5k/IH5OLPtVy8QaW+l9a2XGXfcqWSNKbI6HBmmP9bVket0col6BWQUS1kyYE5BHPSV0u/ztuzBZaean68CLjDrBS4DnnC5PTv8Pu9rgz0x0FBPLBrbg1sOLhIJ01i/PaHXtAOr2xVsMWjdHGbi0bGExpGodnW+/AxGOIzzkKNt8fO3+ueVLMlo1+ZwhBebm/li0RhiTQ00JvTq8dHPK72ke7sqqgfeQSCVPQMnAv7BHnS5PQV+nzc4zGv2xf8jv8+7zvz61y635yJzdcRBk4GyJOwk11i/fdB/6HRmdbvWf9QBNDLj2CoqqhM3TJCIdhnRKIF3XmXMfodQffCRCYttJKz+eSVLMtr1yw0byHM6uXL63pRaNESgn1d6ydR2pSwZ8Pu8C4BE9/H2JfLNuxxfb26uJBmgblmQnEIH5fukZiWu4Wh//y3CTQ1UX3y11aHIMG3o7ubF5mYuGT/eskRAxC7Svey5xkwIjtrl+Exgg0UxSYLV99ULZNmvXiDw36fIKa+i+JB5Vociw/RQfT35TidfUa2AiL0KCOPhcnvuBqb6fd6L/D5v1OX2/AK43eX2vA+sAC4HDjE/S5rraozQtjXC7DNSs078cAS3bKDzo2VUnXcZjqwsq8ORYVjf3c1/m5v56vjxlGan3cugSMLZ6n+By+1ZYy4c1Ndjscbl9hjAI36ft+/NfYJ5Tp+fmuf/Gyg3d1X8vN/nfd+CJkiCfby+wFz7rS8QWPA0jpwcSk843epQZJgeqq+nwOnkYvUKiIDdkgG/zzs7jnPm7/K9AdxpfkiGqV8WJLfISdkMe9ULRDs7aHljAWPnnUR2cebtzJbJaru78Tc3M1+9AiIfS/eaAclwvesL5NmuXqD59RcwQkHtQ5CGHqqrU6+AyC6UDIhtdeyM0L49Yrsti41YjOYF/6Fw1v4U7LWP1eHIMNR0d7OgpYXzKysZp14BkY8pGRDbql9mz3qBjhVLCO3cTtnJ6hVINw/V1VGoXgGRT1EyILZVvzxIXrGTsr3tVS8Q8D9F9rhySo44zupQZBjW9fUKVFUxVr0CIp+gZEBsq255kPFz83E47VMv0FO3lY4VSyg96UwcekNJKw/V1THG6eSiqiqrQxGxHSUDYkvt9WE66iNMsNkQQWDB0ziysin97BlWhyLDsK6ri5daWrhAvQIiA1IyILZUb8P1BaLdXbQsfJGSI48nZ1yZ1eHIMPyhvl69AiJDUDIgtlS3LEjeWCele9lnzfjWRX5i3V2UnaLCwXSytquLl1tauLCqihL1CogMSMmA2I5hGNQvCzLBRvUChmEQ8D9N/vRZFMzY1+pwZBj+UFdHUVaWegVEhqBkQGynvS5CZ0PUVkMEnSvfo2f7ZspdX8DhsEeCIru3pquLV1pbubCykmL1CogMSsmA2E5fvYCdFhsK+J8mq3gsJUedaHUoMgx9vQIXqldAZEhKBsR26pYFyR/nZOxUe9QLhBrqaX//bUo/ewbOXHuteSCDW93VxautrVxUVaVeAZHdUDIgttJXL1A9N9823fGBBf8BB5SddKbVocgw/KGujuKsLC5Qr4DIbikZEFtp2xahqylqmyGCWE+Qlteep+SwY8gp15tKuvioq4vX+noFsrKsDkfE9pQMiK307Ucw4eACq0MBoPWtV4h2tmt3wjTT1ytwvnoFROKigTSxlbrlQQrKsyiZbP2vZu90wqfImzKdwjkHWR2OxOmjri5eb23lGxMmqFfABgzDIBAIEIvFrA4lIYKhCA0NDVaHsVtOp5OysrK4h1utf8UVMfXWC3Qz4eACW9QLdK39kODmWiZcer0t4pH4PLh9OyXqFbCNQCDAmDFjyM+3x9DfSIXDIXJy7F9IHAwGCQQClJeXx3W+hgnENlq3hOlujlFtk3qBgP8pnIVFjJt3ktWhSJxWdXaysK2Ni6uqKFKvgC3EYrGMSQTSSX5+/rB6Y5QMiG38X72A9S8c4UAjbUsWUnrC6Tjz7VG/ILv3YF0dY7OyOE+9AmntnAe7OOfBLqvDGFU0TCC2UbcsSGFFFsUTrf+1DLz8DBgGZSd/3upQJE4fdnbyRlsbV0+cqF4B+YT6HTu54pvXss/MGQCEw2H2mjaVa6+5iqw9/F25486fcfbn3cw96IAER2sN6191RYCOHRG2vxdkymesrxeIhUM0v/IcRXOPIrdqoqWxSPz+0NcrUFlpdShiQ5MnTeLn9/zk4+9/9vNf8/KrC3F9TquKomRA7CDUGWPBbTswYgYHXTDW6nBoe+d1om0tlGt3wrTxQWcni9rauGbiRMaoV0DisO+c2Wzbvp0H//gXVn20hnA4zJnuU3Gf5uIvf3uM7mCQLVu2sq2unmuuuowjDj+UJ/7l5eVXX2dC9Xg6OzsB6Ozs5N5f/pa2tnZisRjXfONyZs7Ym4vnX8GZZ5zGK68tZL85sxk7toTF77zLfvvO5lvfvMLq5n+KkgGxVCxq8OpPGmjZFMZ193jGTbO+Sjfw36fInTCFMfsfanUoEqeHzF6BL6tXwNa+/0yQldt3X9S2si4KZu3A7uw/0cmPPz+8OqNIJMJbi5dw5umnsK2unl/9z12EQiG+etk3cZ/mwul0snNnA3f9+Ae8s+Rdnnv+v+w7ZzbPPv8iD/7u5zgcTi75+jcAePLpZ5m9zz5ccN4XWbuulvsfeJhf3HsnALNmzuDcc87mnPMu4fbbbuErF57HuRfM5+qrLsPptFfJnpIBsYxhGCy+L8C2Jd0cfX05kw6zvlCvq3Y13etXU33JNThs9p9VBtbXK/At9QrIELZu28aNt9wGwIaNmzj/3HM4et5R/Omvf+f6m75HVlYWLS2tH59/4AH7AVBZWUFHZyfbttcxdcpkcnJyyMnJ/bj+YO26Wi46/0sAzNpnBlu3b//4GrP2mUlWVhbFRWOYNXMGWVlZFBYW0BMKUWCzGRZKBsQyq55sY/V/2jngyyXMPqPY6nCgbzphfgHjjnVZHYrE6cG6OsZlZ6tXIA3E+xd8X4/Ak1cWJuze/WsG7rjzZ0ycOIFlyz9g+YqV/PyeH+N0Ojn7ixd+fL6zX2JpGAZg0L+aqffYpzn6nZWV5RzwawZ5rpX0p49YYvObXbzzQDPTji3k8MtKrQ4HgEhrM22LX2PccaeQVTDG6nAkDis6OnirrY1LqqooVK+AxOnyr3+Vh//8CE2BAOOrKsnOzuaNN98mZhiEw+EBnzNhQjWbt2wlEonQ1dXNunW1AMyZtQ/vLVsBwKqPVjNt2pSUtiVR1DMgKde4rofX7mqgYlYux3+3AofTHqv7Nb/yHEYkrH0I0oh6BWRPTKgez7HHzGPdulrq6ndw03e/z5FHHs7R847kd79/iPKysk89p6S4mFNOPokbbr6NCROqmTVrJrFYjC+c5ebeX/6WG26+FYBvX32lBS0aOcdgXR2jQMIb3li/nYrqzJuKlsh2dTZEePaaOhxZcObvJlBYZl0+2r9dRiTC2hsuIm/yXux18z2WxZQIo+X3cHlHB5euXcu1kyZxyfjxlsY2Epn+82poaKBymMlaMoYJEiVdliMGBvu3H/CvL9v1DLjcnunAn4ETgOl+n3fjbs6fA/wMOBrIAVYDd/h93udSF7XEI9zVO4Uw3B3D/StrE4Fdtb27iEhzExPnX2d1KBKnB+vqKMvO5tyKCqtDkQSzYxKQ6WxVM+ByezzA28CmOM93AC8C7cAMoAJ4FHjKTBLEJmJRg1fvbKB5Q5gTb6ukbG97ZdYB/1PkVFZTdPCRVocicVjW0cHi9nYuGT+eAtUKiIyYrZIBoAw4HngkzvOrgKnAY36ft9Xv84aBh80ej7lJjlWG4Z0HAmxd3M1nrilj8pH2yvqDm2rpWvMBZSefjcOpN5Z08AezV+BL6hUQSQhbJQN+n/dhv8+7Zhjn7wBeAy51uT0VLrcnD7gSaAJeTW60Eq9VT7Xxkbed/b9YwpyzSqwO51Oa/E/hyM1j3PGnWh2KxOF9s1fgq+oVEEmYlA3autyebKBosMf9Pm/LHl76y8DzQINZFNgIfMlMFAYVaKgnFo1/e8d4RCJhGuu3x3FmehlJu3a8H+Od+yKMP9TB9LO7bfXvE4mE2bl+LS2LFlBw2DG0dHRAR4fVYY1Ypv8e/q4xQKnTyYnRSEa0M9N/XsFQhHA4ZHU4iWMYadOeYPenX3MHK1ZNZQXXiYB/sAddbk+B3+cNDueCLrcn16wZWA2cDnQAXwWecbk98/w+74eDPbessnq48e9WplcFD1egNsR799dRPjMX14+qySmwVUdU73+S5W9DJMzEsy8kP0N+dpn8e7ilqIRl2+u5YdIkJqfxDIL+Mvnn1TebYLjV91esXQvAH2bNSlJ0ey6dZhPkFxRQEedMjpQlA36fd8FgUxpG4GTgYOB0v8+70zz2e5fb8w3gUuCGBN9P4tTVGMF/2w5yi5yc/OMq2yUCAEYsRvOCZyiccxD5U/a2OhyJwwPbt1Oenc0Xta6ADMPWbdu5/4E/0trWDsC+c2ZxxWXzyc3JGdZ13nz7HQ4+6IBBk4EX/S/z10ceY8KE3j82u7uDnH7qyXz+jNMGvebF86/god//moICa5djt8/crpHZNcnITsY6AhKfcHeMBd/fSagjxhm/mkBhhT1/zXo+Wka4sZ7xF9hvBzH5tBU9IZZ2dHDj5Mnka98IiVM0GuWOO3/G1VddxtyDDsAwDO574I/8/bEnuPSrFw/rWv/r/Q8H7Df0RLUTjj+WKy+bD0AoFOIb37qRIw8/lPHjq0bUjmSz56v0EFxuz93AVL/PexHwJlAP/NTl9lwHdAIXA7PNQkJJsVjU4LW7GgjUhvjcj6som2Hf7rTuRX6yyyopOewYq0ORODzS3kF5djbnaAaBDMO77y1j6pTJzD3oAAAcDgeXX3rJxxuRPfnUM7zy2hs4HHD0vKM4/9xzBtzCuKW1jY9Wr+UHd9zNvXffQU4cvQq5ublM32sa2+vqqago59e/e4C6unpC4Qhfu+RCDp574MfnbtiwiV/97gFycrJxOBx8/3s3sXHjZv715NP8+Pbv8eHKVTz2+L+568c/SMq/k62SAZfbswaY1m+WwxqX22MAj/h93svNYxPMc/D7vC0ut+dU4C5gLZBrfj7X7/MutK4lo9fSPzSz5a1ujrqmjClH2WsKYX892zYRWreSqnMvxaGKdNtb2t7OilBIvQJp7n+2bGFtd/duz1vT1bsCYV/twFBmFRTwnSmD7wewddt2Zuw9/RPH8vLyAKir38GLC17mvl/di8Ph4Jrrbub4Y+cNuIXxD7//Xf76yGPc8YP/F1ciANDc0sLadTVc/Y3LeOW1hZSVlnLDtVfT2trGzd+7nQfv++X/ndvaypWXz2e/ObP5298f5+VXXucLZ53BC/6XePe9ZTz6j3/xnRuuieu+e8JWyYDf550dxznzd/l+BXBmUgOTuKz+Txsr/7eNfT3F7PcF+00h7K/J/zRkZVN6otvqUGQ3NgWD/HDTJsqdTvUKyLBFIhFisYFnjtXUrmf/feeQnd37VrjvnFnUru9d9HbXLYzj9drrb7B2XQ2hUJiWlhau+eYVlI4bx5q1NSxbvoIPVq4CoKen5xObIpUUF/PHP/+NcDhMY1OAz514PABXfP2rfOv6Wzj15JOYOGHCCP4lhmarZEDS19Z3unj7dwEmH1XAkVd9epMPO2l7dxHNrzxL/mHHkF0yzupwZAgfdXXxrZoaDODHZaXqFUhzQ/0F318iZxNMmzaVZ30vfuJYKBRi2/a6T51rGAZOR+/v2Ke3MI5PX81AT08P3/z2d9h7+rSPHzv/y1/kc589YcDn3f/gHznv3HM46ojDePxfTxLq6Z2+2NXVRU5ONo1NTXHHsCf0P0tGLLA+xKs/aaB0eg4n3lqJM8seuxAOpOPDd9n6u59QsNcsis8aXvGQpNY77e1cuXYt+U4nf5o1i1m5w6v8FgE49OCDqKur58233wHzjf2Pf36El19dyKyZM1i5ajWRSIRIJMJHq9cyc8b0Qa/lcDqJRqNx3TcvL4+LLvgyv//DnwDYd/Ys3nyrN4bmlhb+9Ne/f+L89vYOJk2oJhQKsfidpYQjEQDue+Bhbr3lRpqaAqxaHfeafMOmZEBGpKspwoJbd5BT4OTkn4wnp9C+v1Jda1ey+Ze3kzthMlNvugtnvrVTeWRwLzU38+2aGqpzc/nTrFlMy8+3OiRJUzk5Odz5o9vwPf9frrnuJr59w3fJy8tj/lcuYPz4Kk475XNcf9OtXP+d73HaqScPWfU/98D9ueXWH9La2sadP/05PT09Q977pBOPIxBoZul7yzjh+GMoLCzk2hu/y/d/eCcH7LfvJ871nH0mP7rzHn589/9w7jln8/Irr/PIo09QWVHOjL2nc8Vl87nv938kGo3Gde/h0hbGCZTpi4fsKtwd44Ub62nZHOb0X1ZTsU+eJfHFo3vjOjbe/R2yS8ax162/JGdc2aj7eaWLJxsbuXvzZg4YM4ZfzZjBWHM8N93bNZhMb9eebGGsRYcSI623MJb0YMQMXr+nkcZ1IT73oypbJwI92zax6WffJatgDNNu+Rk54+xd0zBaGYbBn+rrub+ujmNKSrhn770pUI3AqGTHJCDTKRmQPbL0oWY2v9HFkd8oZerR9p1CGNpZx8Z7bsbhdDLtuz8jtyIzlrDNNDHD4Bdbt/KPhgbcZWX8YNo0chz2rT0RyTRKBmTY1jzbzof/amPOWcXsd459pxCGA41s/OlNGKEQe936c/KqJ1sdkgwgbBj8aONGnm9u5sKqKq6fNAmnEgGRlFIyIMOybWk3b/2miUlHFnDU1WU4bPqiHWlvZdM9NxNtb2Xad+/V3gM21R2NcsuGDSxqa+PqiRP52vjxtv2dEslkSgYkbs0bQrxyx07GTbP3FMJoVwebfvZdQg31TLvpLgpnDL2WuFijNRLhutpaPuzs5NapU7WgkIiFlAxIXLqboyy4bQfZ+b1TCHPH2LOwKxbsZvPPbyW4ZT1Tr7uDMfsebHVIMoAdoRDX1NSwpaeHe6ZP56TSUqtDEht5/obeBYFO/0XyVtyTT1IyILsV6TFY8JMddLfEcP+imqLx9vy1iYVDbP7ND+la9xGTzGUeJwAAIABJREFUr76V4oOPsjokGcDGYJCra2poj0T47cyZHFFcbHVIksHqd+zkim9eyz4zZ2AYBj09Ic7/8jkcd8y8YW05XLt+A39/7J9875brcbk93Pbd73DC8f+3ydmP77qXltZWfn7PTwDwv/QqTz71DHl5uUQjUc790hc4/tijuftnv+Tss9zsN2e3q++nlD1f1cU2jJjBsgciNK4xOOmHlVTMtucUQiMaZev9d9H5wbtMvPw7jD1q4CU/xVqrOjv5Vm0tDuDBWbPYt9C+M1Ekc0yeNOnjN+mOjk6uvPp6jjz8UBjGlsO/ue8PfO/m6wGYUD2e195Y9HEy0NPTw5at2yguLgLgw5Uf8fQzz/Gzu35EcXERzc0tXHfT95i+1zSuvGw+t97+E+7/zf/Yqj5GyUCaiUUMNr7eSbB14I03Eq1xbQ91SwyOuLKUaceMSck9h8uIxdj20L20L32D6ouvpvT4T2f1Yr3FbW18Z/16xmVnc9/MmUzVqoJigaKiMZSVlRJobvnUY/23HO6fDKxavYbScWMZP76KcDhEZWUFDTsb6Q4GKcjP550l73HA/vuyafMWAJ76z3NcctH5HycHpaXjuO9X91JU1PsaOnnyRN5ftoJDD5mbsnbvjpKBNLJtaTeL7w/Qujkcx9mJs5fLyf5fsucUQsMwqPvbb2ldtICqL32N8lM9VockA1jQ3MxtGzcyNS+P+2bOpDI3PVZwk8RafH8TgZrQbs9rqu09p692YChlM3M56pvlccdQV7+DtrY2qio/XbDaf8vh/pYt/4CDDtz/E8eOPOIw3l68hM+ecByvL3qT0085+eNkYMu2bczYZY+DvkQA4OCDDmTZig+UDMjwtG4Ns+TBAFve6qZ4YjYn/aiK6gNT1F3vhPaOHbbqzupjGAY7nvgjzS89Q/kZ51Fx1oVWhyQD+HdDAz/dsoWDzOWFS7L1siOptXXbNm685TYMwyAnO4dbvnMtWeauhINtOdxfY2MTcw864BPHjj/uaP78t0c5Zt5RNDY2MXHi/xU7RiPRQbdNBqioKGflqo8S3s6R0P9KGwt1xlj+aAurnmzDmePg8MtL2c9TQlZuat+Y2ztSeru4NT7zD5qee4LSk85k/HmX2TJhGc0Mw+CP9fU8UFfHsSUl/FTLC4968f4Fn+jZBP1rBnY11JbDQ5k2dQo7dzbyxqK3OfzQQz712Oo166jsN11285atVFZWUGAOj9ltWyD9z7QhI2aw9vl2npy/lQ//2caMzxXxxb9O4sDzxqY8EbCrpv962fmvPzH2mJOZ8NVvKxGwmZhhcO/WrTxQV8cZZWX8z4wZSgTE1nbdcri/iopymv5/e2caHkWVNeC3OwSSIFtIEEEDBARRQERRcUFEy6VUxtLRcRn9HFccRUXEZUDREWEUxREVHfcF99FyoxTLcUFBZVORTWUR2cKWDUJClu7vR05j03SSTtKku5PzPk+eTtdy65zqqnvPPffcc7fm7rH96COPYNprb3LcsUfvtn3YWafz0rTXyZO4hLy8fP553wNs3LgJgK1bt5KREfnQRkOgnoE4Y+PiEr57PJetv5SSeXALTro3ncyD4jOCP1bkzZxBzsuP0+rwY+l81Wg82sjEFWU+H+NWr2ZGXh4Xd+jATZpeWEkQhg45nvc/dJi34AeOGPBHjpL+/frw5n/f5Zyzz9rt+MHHH8PMr2fTJesAcqShBzi0bx8uvuA8xowbT4sWlfExw6++nK5dsgBY+NMSjJOGNJhekaDGQJxQtLmceU/nsfKzItLaJzH4jgyyh7bUHm8IBXNmsv6Zh2jZ53D2v24MHhn3U+KD4ooKRq9axTeFhVzfqROXaXphJcZ03LcDU6c8GHbfqcbQPbb9+8GJe2w7uPdB5OXns2nTZtq1a7NryKFrlyyeeXLKrusED0UMOeE4hpxw3B5l5RcUsGbturgKHkSNgdhTvtPHorcK+en1AvwVfg69uA19L2hDcqr2dkPZ9sN3rJs6gbQDe5N14914E2RN8aZCfnk5Ny5fzpIdOxiblYWl6YWVOhKPmQdvuO4annz6ee649aZ6lfPkU89x3bXxF+OkxkCM8Pv9/DZzB/OeymX7xgq6Dk7jiKvb0apjcqxFi0uKlv7Imin30OKAbmSNug9vSmqsRVKCyJH0wut27uT+7GyGhkRjK0qi0z27G3eNuZWyspqnRlbH7aNHRk2maKLGQAzYunwn303NZePCnbTLTua0BzPYr782blWxY8Uyfp88luaZHekyeiJJafvEWiQliFUlJVz3669sr6jg0R49OELTCytKwqHGQANSkl/B/Ofz+PWj7TTfx8ugG9PpabaK29X/4oGSNSv5fdIdJLVqQ5fbHqBZa+1xxhOLioq4YflykjwenurZk4M0vbCiJCRqDDQAvnI/S98r5IeX8ikr9tP77Fb0v6QtLVpp8BsyZJJbXs6qkhJ2BCXq8G5aT5uH74RmyRRcdyc5SclQUBC16xaWlNA6iuXFBX4/BSUltMrPxyfffYBf7rMfwn5HpgMG/+8POi7cOTt9Pl7cuJF0SS98gKYXVpSERY2BvczaOTuY82QeBb+X0enwFI76ezptuzTNwLdAo7+ypISVxcWsKCnZ9X9BRcVux6Zvy2fsG1PZXl7GhPOGs76gCApWRF+o3D3zkzcKGkivXqmpPNKjB5nJGuuiRI9V990MQLcxk2MtSpNBjYG9RMHaMuY8kcva74pp1bkZJ4/vwP5HpcZdBOneIresjBUlJawqKWFFcTEr5TO40W+VlER2SgontWtHt5QUslNSaJ2UhL8wH179N5SXwcjx3J+VvVdkzN+6hbbtG1/Ee8HWLbTLyMQTyCrm8eAFPIDH49m1PfR7XY5L8Xo1h4AS1+yNJYwBFi9Zxk233MGTj02me3a3PY6viZemvU7rNq05+ywz7P6ZX89m8HHHMHfeAnI2bgor07ffzWX+9z9y3fArw5ZRG9QYiDKl23388Eo+S+1CkprHLoVwQ5EnjX6gsV9VUsKKkhLyy8t3HbNPUhLdU1IY2rYt2ampZKek0D0lhYzk5D2Mo4qibax69B5K83Lpetu/SOvZJ8xVo8OWbQVktIzPlRjrw5ZtBWTo2L2i7CLaSxgDfPbFTLIO2J/Pv/yqTsZAdZSXl/O2/T6DjzuGgSJnOI4+aiAfzfiUpct+ofdBPet1TTUGooSvws/qzyv45e21lBT4OPDUfRhweVvS0hvHLQ5u9FcGevohjX5Lr5fuqakMadOG7jU0+uGoKN7B6kl3ULphDVk3j9+rhoCiKE2TaCxhXFFRwdezv2Xs7bfwwEOPcOXfLgXghZdepbikhDVr1rJuQw7XD7+SgUcM4L/vvMfMr2fj9/k5cuDhXHLxX3aVfe+ESZxhnsqA/v0oLSvjimtGMPDww1j122qmPP4fevU8kN9W/841V17GW++8x6zZ31JRXsEVf7uE/of25cwzTuO9D6arMRAP/P7zRn4ffyfJpbB/qp9t+/nJ3QCf3hdryaKD3+MHb+WqGi2A3n4PB/s8eHwePD4vHp8Xr88D/t0b/DxgXi2u06o0j9Y7t/BG79tZOrs3zN4RZU12p6w0leTme/casUD1Siwau14X9S2nP5XDg+X2E/jWr6zxXP+6yvigJXffXOOx3k7ZNLOurXL/1q0V7Cz3s3xzpQxbNm9ka14hhbRj0zYf+Tt8u/YVFhaweNlyzPMv37UN4H+zf2S/br1ZvrkCv8/LkkXfk9mxM6n79qJ56j588s0Ssnv0Ir8E1q3dxDU3jGHRjwt4490ZtOtyKFu2+xg+ahzJyc25c/R19D/eJHeHj1Kvjz4DB/PBjK9o3fkQFv34A70OOYyjThzGj4t/4dyLr2LR3C9All6eOXMWj0z+FxtyNvL6m2/T/9C+9OtzMA9PmVrjfaqJuDIGDNPaF7gfOA1IBRYD/3Ad+4tqzkkFJgHnAa3lnFtdx/6soeROaZNEuSeFvAw/xfvIYGpjogK8FUl4fR48FV48fg/+UCWjkDBxZ2oaH3W7gqXtj47gaEVRlMjYmLOeyRPvAj8kNWvGZdeM2LWE8fw5s1n92wrKysrYVlDABZdeSevWbXY7Pz83l569D9n1fe43X3HEUZWphgcOOp55384iu0cvAHr07A1Au/YZFO8oAsCblMSUSffi9XrZvq2QHdv/WAr24L6H8c4b06goL+fHBXMZdHz4NQuWr1hJz5498Hq9dO60H6Nuuh5kgaXy8nIqKip26VQX4soYAN6TDmV/IB8YB3xomFZP17HXV3HOVGAQcCKwGhgOTDdMq5/r2L82hNAdOmbQ4aUpbMlZT0bHTg1xyQalIfU6u0GuUsmWnHwyOrZrwCs2DKpXYtHY9dq8uYjMTGmkrr4+onOjOZtgH18SWft3ZurkPV21K1p5OWnI7ksYH9u/G5kZuzeqbVI9dGztpUdmEtu3b2fxwvlsXLeKb778mPKycrYXFXH7jVeQnualdZtkemQmkVTkJSUZWpZvZtZnH/HEo5NJS0vl8muup0v7JH5O89K6lZdeHZszaGB/CtctJnfjGk4ZdDA5GzfRopmHzm29LAqSw78X1z2OmwT4hmkFevUjXcfOcR27RLwELYGwXUXDtNKBvwK3uY69xHXsItexHwKWilGgKIqiKDUS6RLG382dT/9+fXn6iSn857GHefY/j5K1f2d+XLgoTKmwbXsRbdu2IS0tlSXLfmbLlq2Ul5XtdszJQ4fwwsuv0b9fZZyU1+Ohwrf7dOse3bNZuuwXKioqyMvLZ9w/KxdUKi0tpVlSUr28AsSTMeA6dqHr2Fe4jr0saHNgTllVXoEB4t2YG7J9TlUGhKIoiqKEY+iQ48nNzWPegh92296/Xx8W/rQYgC9nztpjtcNTjJP4/MuvwpbZPbsrLdPSGDn6H8ya9S3Dzjydx558Zrdjeh7Yne3btzP0xMEApKe3o7ysnPET/1htcb+O+zL0xMGMHP0P7rpnAn+SKYmLliyjb59DqC+evel2CMYwrWZAlUnlXcfODzm+NfAVsMZ17DOrKPMi4BUgxXXsnUHbxwMXuo7dvarr5W7O8fsqfFXtrhPl5WU0a9b4kq+oXomF6pVYNHa9ikvL6Sjz+CNlzf23AXDAbffvJelqz6hb7+S20TfSIaM9RDm3xtp163ls6tP8675xtT53wv2TOefsszio14F77MvZkENq892jATI6dgorfEPGDAwB3Kp2GqaVKkMDGKbVBfgQ2ARcWIdreSR7apWkZ9bu4YwEjRlILFSvxEL1SiwCem3evJnkWi43nj324b0mV125ccRwnn1+GnfcelOt9amO6R99wvSPPmH0zSNqXe6cufPJyMio0jOQkppKRmZmRGU1mDHgOvankcTZG6Y1UAwBGxjhOnZZNYfnyGcmsDZoe4egfYqiKIpSL6K1hHEoZ5x+Cmecfkqdzj1y4OEcOfDwqMgRV7MJDNPqA3wM/NN17EciOGU+UCqzCd4K2n6MGBSKoiiKotRA3BgDhmklAS8CU6szBAzTmghkuY59sevYBYZpPQtMNExriUwtHAV0kSmHiqIoiqLUQNwYA9K7HwD0MUxrdMi+l13Hvkr+308a+wAjZQri50Ar4AfgFNexVzeg7IqiKEoYvF4vJSUlpOgS1w1KSUkJXm/kEwYbbDZBHBJ1xRt7IFBjQ/VKLFSvxCKgl9/vJzc3F58vurO3YkVJcTEpqamxFqNGvF4v6enp4daFiflsAkVRFKWJ4fF4aN++fazFiBpbctZHHKGfSMRN0iFFURRFUWKDGgOKoiiK0sRRY0BRFEVRmjhNOYBQURRFURT1DCiKoiiKosaAoiiKojRx1BhQFEVRlCaOGgOKoiiK0sRRY0BRFEVRmjhqDCiKoihKE0fTEdcTw7RSgUnAeUBrYDFwq+vYn8VatvpgmNa+sgDUaUCq6PUP17G/iLVs0cIwrWOBmcC9rmPfHWt56othWpcDt8lCXuuBR13HfjjWctUHw7QOAh6QZcmTgWWyxPn0WMtWWwzT6gY8D5wAdHMd+7egfQlbj9SgV8LWI9XpFXJco6hH1DNQf6YCJwMnAhnAa8B0w7QOjLVg9eQ9YF+gv3x+CXxomFajWFFFKt/nge2xliUaGKb1F+Au4FKgDTACGG6Y1sBYy1ZXDNPyADOAbUB3eb9eAd4VIyFhMEzLAr6VZdbDkZD1SAR6JWQ9EoFegeMaTT2inoF6YJhWOvBX4M+uYy+RzQ8ZpnUxMBwYFWMR64RhWoGeySTXsXNk2/3A7cDRwDuxljEKTJBe5vpYCxIlxgGjXcf+Tr5Pl79EpgOQBbzqOnYBlc/hs8AjwKHy+yUK6cBg4AAx2HaR4PVIdXolcj1SpV4hNJp6RI2B+jFA7uHckO1z5GFPSFzHLgSuCNmcLZ8J/9AbpnUccAnQV3pgCY1hWvsBvYEkw7TmAb2AlcB9rmO/GWv56orr2BsN0/oSuNwwre/EQ3ANsBWIezdzMK5jP0vlb3VAmN0JW49Up1ci1yM1/F7QCOsRHSaoHx3kc2vI9i1AxxjIs1cQC/95YLrr2N/GWp76YJhWmuhyk+vYG2ItT5TIks/hwIXy7D0HvGGY1gkxlq2+nA90BTYDxdKr/LPr2BtjLVgU0XokwWiM9YgaA3sHD9AoFn0wTKsLMEsqpgtjLU8UmAAscR17WqwFiSIBD989rmP/6jp2kevYj0hP828xlq3OGKbVXGIGfpHx5n1kOOQDw7T6xFq+BkDrkfil0dUjagzUjxz5zAzZ3iFoX8IiwWdz5CU+zXXsbbGWqT6IW+8icTU3JrbIZ17I9pXAfjGQJ1qcLIFnI13H3uQ69g7XsZ8AVgGXx1q4KKL1SALRWOsRjRmoH/OBUmAQ8FbQ9mOAD2MoV72RntfHMo3rkVjLEyWukGlbiwzTCmxrAxxpmNYw17EHxFa8OrNcDIKjgB+CtvcA5sVQrmjhCfnerLH0mAWtRxKLRlmPqDFQD1zHLpDo5omGaS2RaSijZJ731FjLV1cM00oCXgSmNqIXGOBm4M6QbW8B38hc9oTEdewKw7QmA+MM0/oeWAhcBRwmn4nKbOkZ/8swrZuAIom679WYemVajyQcjbIe8fj9jcnAbngM02ohSTUuAlpJz2yU69izYy1bXRE32FfSWwl9QF52HTuRG5jdMEzrC+CLRE4Wwh9z8v8hjWR7me50p+vYTqxlqw+GafWT8dmjgOYSPzDRdex4npa2B4Zp/SyNu1eSJwXerZddx74qUeuR6vQSQyAh65Gafq8wxyd8PaLGgKIoiqI0cTSAUFEURVGaOGoMKIqiKEoTR40BRVEURWniqDGgKIqiKE0cNQYURVEUpYmjxoCiKIqiNHE06ZASV8TTfF3DtCYA1wMLXcc+rg7nx40udcUwrbHAZa5j94ixHPcCJwJDgE8T8b4apvU1sNx17MvC7BsMfAIc7Dr2ygjKqvZ3kXnyr7qOfU8EZQ0BPgcOdB17eS1UihjDtEqA4a5jvxDlcveXLJv/5zr2jGiW3dRQY0BRwmCYVitZIW8kMCXW8jQUUrme4jr2c1RmxxsPjI+xTKcBI4A+rmOXB6WAbTS4jj0TSKnF8bv9LoZpnQmsdx17gezvtbdkrQuuY0esG5X6/BWYXZNh5Dr2WsO0rgReNUyrr+vYcb00cjyjwwRKo8YwreQ6npouOfEXu44d95m5DNPySPrX+mLF0yJAotNkSWm7NtbyxDH3AAmZEz8Uyab5MJAdyfGuY38ILA6TIlipBeoZUMIijWgpcDFwNnC65IZ/1HXs++SYF4AewS50w7SmAfu7jj3EMC1DXJ9Dgcfk5f5GUq6Ok6VMd0p62X8HXT7FMK1ngHNFBlvWDS+RawwFJgJ9ZQ34/wG3uI69Rfb7JX/4tcAG4IQw+qUB9wHDZHW4FcCDrmNPM0zrGHGbAjiGaX3vOvZRYcroIpXWIEmVuxC41XXsuZHoYphWijR0ZwNtgY3A03I//IZppUqu8/OAlsCvwH2uY78t178bOBNwgRsAyzCtGUBX17FXB8n5IpDtOvbxhmn1Av4t6X2TgB9FngWGaU0EbgU84tYdDJjAla5j7x+JzoZpzZLV6UplXYQU4CPgb65jF9ekc5hH8VSgN3BSyPY6PyOGac0Efncd+69B92hfYB0wzHVsxzCtc4AxwIFAruhwm+vYhYZpHShpkU8FRsu9yAXGuY79vJTXFvgPcApQDFSbmz/UVW+Y1jpgEnAQcL6kwn0duF6ejbsDv4thWmuBzsAThmmNcB37UMO0fgOmuY49Vsq/Fbga6CT34yVgbCSGrvymrry/w+SdtWU1yWL+GOaYCBwi9cSXwM2uY+fwxzt5levYzxim9Yq0PV8BtwHtgK+BS+XcXHm2HMO0PnYde5hhWhdIuu1uQBnwhdyLgCdgingHbncdu6AmnZQ9Uc+AEhbXscvk37ukImsDjAXGG6Z1SITFBMoYIWO93WWRmW+kcs4QV/wkw7Qygs67WvZ3FCPkXOn5BBqjD4EnpTE5Uo57LeTal0vFdWIVsk0FDOAMyeX/T+AFw7TOk3zwATerWYUh0EzGrsukwekCfAt8aZhWViS6ADcBx0mPrqU0+jdII4M0mocDR0uFOR54XdaOCBDIn95OrrNBGo+AnC2k4X1RNv1XKtwskWkV8A6Vv/kdklN+tuvYKa5jz6mDzmXAZcDvwAHAsXL9gLehJp1DOR34yXXsDSHb6/OMPAWcKw12gPNlUaQZhmmdIAvPPCjLCp8kDX7g/MBzfa8MI7UGXpDGOF32PSTLLw8Qg6KtfI+UMuAWYLrI8Bfg72Kc7UbAUAOudR370ND9hmmdK+s7/J/c87NF7j1iF6qRZYQ8J+2B04Bz5J3BMK0e8lu8I8bG4SKzK738cOWdKMf0FIOnPzDadeyikHdvmGFanYFpYji0lvvpl98nwP/EuDUi1EkJQT0DSk287zr2LCpf+jekF3eIuOUi5UnXsTfzRy9jX9ex/xtU5nNiKGyR4791HTtQ8S6QnsQ5Uhn8HZgf6IEBGwzTGg0sNEwrO2iMcYbr2MvCCWOYVmvgEuDioGPeMUzrE6kg3wp3XginyRLBg4M8EuPEG/EX6dXVpEtHwAcUSw9tnmFa+0nPr5X0rE91Hfu3IBnfl4bwa9mWLt6CUpHhNWnYAtc3ZaGVgE5HA76gHt0bwKWGaXUM9OKioPNvrmM/Jf8vMkxrsTwzVKdzFdfsJ96LUOr8jIhB9Ih4qAKrAl4IvCgrQI4A3KDyV0gw6RuGaXUIkuFF17EXyX14S9zUPYA5cj/GuI69Kug+Da/h/oYy23XsD+T/Tw3TypX7OL2W5dhApuvYefJ9vmFai8Q79HwN5waY5Tr2u/L/93K/LfGMDAdWuI79kOzfYZjWGOA74AhgbpjyKoB7XMf2AWulXqiqk5EpDX3gmdlqmNafg58Z17HzDNP6HThUfl+llqgxoNREcHRxsXym1bKM34L+3yFLtELlS7xDAsJSg45ZGnL+SunJIr2GY8SNHUyFuBBXBp1TFdnSmw69zs/Sy4yEHkB+cI/VdexSw7RWiWETiS4PyhDGenFdzwBeBTZJGUniKg1uKL3iWQmw1XXswqDv04CbDdPqJg3RhcC7Qa7TIcBYGS5IC/IORhLgFanOoRHpxUHPTHU6hyMTmB9me52fEdexVxqm9bKsSz/VMK2u0jBeEqTn5yHn/iyf2eJBCNVz17thmFZ76YGvCuwMuk+1obr7WBtaARMM0zod2FdiYZoDS2pRRuixq4Lud48q3iXkuQhnDKwSQyBAsXgd9sB17B8M03pYDKJF4p16S4yNYDbLkJ9SB9QYUGrCF8ExwYQbegoto6YyK0K+e4BAxe4DpruOPayGMkpr2B8oNxhvmKVWa3N+uDKq1EUioQeIW/UUic8YJ+PdgfMGBSLEq2A3PV3H/l564ucbpvW4DIOcQ2XvtJu4ch8Tj0OhYVqnSIMcTZ2r/H2r09l17O/DnOKv4jep7zPyNHCjYVp95R7NdB17RTV6Bp7rSPRsEeZYxLirDbV996riQfHqnAvMcx3bZ5hWbZdGDn2vPUEGEBHer2BqpZvr2DcbpjVJhpNOA2YapnW/69h3VSODUgs0ZkCpD8XSwwgmq4pja0PvkO/dZQwaCdzqZ5jWrmfXMK0Uw7Q61aL8FdKYhLole0v5kfAz0EbGM3fJAXQNKaNKXQzTagmkuI49z3XsCa5jDwQWyNjuSpHxsOCTDdPKkrH76pgmFf+ZQL70pJAGuDlwd5A3oTYR6JHqXCU16ByOzeIdCKVez4jr2IuB2cAFMlzwXIie4Z4NX5jeejg2iZHWNfj64rmKBYOAN13HniOGQJqM09eG0OO7B3n4fgYODtkf+H0ifZ+qxDAtr2Fa6a5jb3Ad+wXXsS+QIbQbQg7NqMbDpNSAegaU+rAUuMQwrS6uY682TMuUwLJwPbzacKxhWn8G3pcx44skIAvgCQlCu1ei35Ol5zPIMK0+Ia7HsLiOvc0wreeBMYZpzZVK7RwJFDsrQhk/kUrwYcO0hkvDHUiC83qEurwHbDZM60bXsTcZptVdgvLecR17u0TLjzVMa4FE7R8r47/XhVwjlFck2PBqiSgP9KIDlfdJhml9IGO+gWGRLBnOKQI6i6u7OKTcSHWujip1ruL4hTKrIZRoPCNPycwKT8g486MSFHkh8LYMDYwB3nYde6vEc1SJ5EJwgGslxmOr3Ke9OUV1B9DTMK32rmNvDdm3GjhaDLFWEiexFjigigC/cJwQ5n4HAvieBEYapnUL8Li4+ycAc6rw9tREkXweZJjWfPHcPGCY1tky5JAmQaG/Bk4wTKudPMML63A9RT0DSj15FvgYmGOY1hJx4T1bTyMzWaYJDZNe4XRpaB6isqJdLT3ek2X/Sql8To/EEAhipExtmiWBi6OAc13HdiI5WRrYYcA+0iNdIb0/trGBAAABm0lEQVShY4PG1KvVRXrDzYHFhmntkOlbr0hjhkyPnC73uEhc23e5jl1tw+s69hpgpkRsvxS0fa5Mp3xGdD5dZPsaeF+i6KdJ/MAKuce11bkmatI5lI+klx88FhytZ+RNMQReCwRUyvlfSyDpXUAh4IgckUbfIz3XnyX48VcgT563uua9qIkpki0zXHzFaLnnm8RL9JQYJwOlcY+Ep4E/yf38UOI8JlN5v1bJvgtlWuA38mzsMfMhEiTY+BUJSH1P/n9crlkkRmtnCdIMcJJ4bty6XFMBj98f9/lUFEVpooir/yfpld8VwSm1KbuTGAoDXcf+KZplNyYSIa22YVpfAUtdx7461rIkKuoZUBQlbpGe/ChgRHCsQn2RfADPydRZNQQSGMO0zpIYjxrXYVCqRo0BRVHiGtexP5Zx/NcjCJ6sEcO0bpdsg2XANdGRUokFspbGM8BFrmOvi7U8iYwOEyiKoihKE0c9A4qiKIrSxFFjQFEURVGaOGoMKIqiKEoTR40BRVEURWniqDGgKIqiKE0cNQYURVEUpYnz/6wuDkqIEfxDAAAAAElFTkSuQmCC\n" + }, + "metadata": { + "bento_obj_id": "139712287095440", + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e9baed40-d083-47fb-aa14-03addfd33948", + "showInput": true, + "customInput": null + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/website/tutorials.json b/website/tutorials.json index 1534de39c1..915fc6548f 100644 --- a/website/tutorials.json +++ b/website/tutorials.json @@ -14,6 +14,10 @@ "id": "closed_loop_botorch_only", "title": "q-Noisy Constrained EI" }, + { + "id": "discrete_mixed_bo", + "title": "Bayesian optimization over Discrete and Mixed Spaces via Probabilistic Reparameterization" + }, { "id": "preference_bo", "title": "Bayesian optimization with pairwise comparison data"