
[RFC] Max-Value Entropy Search #89

Closed
wants to merge 298 commits into from

Conversation


@gpleiss gpleiss commented Apr 24, 2019

No description provided.

Balandat and others added 30 commits December 14, 2018 16:20
Summary:
Fixes an issue with taking gradients w.r.t. train inputs that came up in recent gpytorch changes.
See https://pytorch.slack.com/archives/GBSTFC62E/p1544807635222300 for details.

Reviewed By: bkarrer

Differential Revision: D13468268

fbshipit-source-id: 79634d84676adfa6e930ad06adfca05e04e5fce2
Summary:
Main change here is that we sample the inner MC samples from the
posterior over observations, not latent functions. This accounts for
heteroskedasticity in the noise model, which affects the amount of additional
information we expect from observations at the new points.

Clean up docstrings, remove unnecessary fast_pred_var

Reviewed By: bkarrer

Differential Revision: D13455187

fbshipit-source-id: 579ff9d04363bdfad2df859a2f620ac502bbc4c0
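A minimal sketch of the distinction described above, using the public BoTorch model interface and the `num_samples`-style sampler constructor seen elsewhere in this PR (the model, data, and sample count are illustrative, not taken from the diff):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.sampling import SobolQMCNormalSampler

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

X_new = torch.rand(4, 2, dtype=torch.double)
posterior_f = model.posterior(X_new)                          # posterior over the latent function f
posterior_y = model.posterior(X_new, observation_noise=True)  # posterior over observations y = f + noise
sampler = SobolQMCNormalSampler(num_samples=16)
samples = sampler(posterior_y)  # inner MC samples drawn from the posterior over observations
```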
Summary: Using torch.Tensor as type annotation gets pretty verbose, and usually there is no ambiguity in just using Tensor.

Reviewed By: bkarrer

Differential Revision: D13475378

fbshipit-source-id: 06dcc90c01a47bf811972cea4602d896b154eb27
Summary: Add reinitialize method to model interface. This will allow us to run model-agnostic closed loops.

Reviewed By: Balandat

Differential Revision: D13490599

fbshipit-source-id: 05612c389302cb842088c275671eeaf07e0496de
Summary:
These are failing for me on the GPU on master b/c of numerical issues.
E.g. `torch.tensor([1.0, 1.0, 1.0], device="cuda").mean().item()` may not be exactly 1.0, but 1e-6 close.

I am not sure whether these have been failing all along. Note that sandcastle won't pick up this kind of stuff since it doesn't run the cuda tests.
sdaulton, did you run these successfully on a devgpu?

If so this may be due to some changes deep in the cuda backend. If that's the case, let's see if that was a bad commit and whether they keep on failing going forward.

Reviewed By: bletham

Differential Revision: D13517020

fbshipit-source-id: 6cec0eb4b29215756db5ae1dfd932a944b089302
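Hedged illustration of the kind of check that sidesteps this (assuming a CUDA device is available; the tolerance is arbitrary):

```python
import torch

x = torch.tensor([1.0, 1.0, 1.0], device="cuda")
# Exact equality can fail on the GPU due to reduction order / float rounding,
# so compare with a tolerance instead of checking `== 1.0`.
assert abs(x.mean().item() - 1.0) < 1e-5
```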
Summary:
Adds the ability to specify linear (non-box) constraints on the input parameters of the scipy candidate generation.
Useful e.g. for enforcing order constraints, or optimizing subject to a budget constraint in the fidelity-aware setting.

While the scipy optimizer also allows general non-linear constraints, this diff only supports linear constraints (the optimization problems are typically hard enough, we don't need to suggest people can do all kinds of crazy things with ease...).

This is botorch-only for now, and still needs to be hooked up with Lazarus.

Reviewed By: bletham

Differential Revision: D13427014

fbshipit-source-id: 282bdca4e550e7b8efac69dee4ad9d77a986af3e
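Not the API added in this diff, but a small sketch of what a linear (order) constraint looks like at the scipy level, which is where these constraints ultimately end up:

```python
import numpy as np
from scipy.optimize import minimize

def neg_acq(x):
    # stand-in objective; the real code evaluates the acquisition function
    return float(np.sum((x - 0.5) ** 2))

# linear order constraint x[0] <= x[1], expressed as x[1] - x[0] >= 0
constraints = [{"type": "ineq", "fun": lambda x: x[1] - x[0]}]
bounds = [(0.0, 1.0), (0.0, 1.0)]
res = minimize(neg_acq, x0=np.array([0.2, 0.8]), bounds=bounds, constraints=constraints)
```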
Summary: See title

Reviewed By: sdaulton

Differential Revision: D13552389

fbshipit-source-id: 68892a10fb58f88ecdcf174dcde202fb8c200302
Summary: See title

Reviewed By: sdaulton

Differential Revision: D13552393

fbshipit-source-id: 39b7ba2976cb8bdd9f50798612693ae42bb3aabb
Summary: So far the benchmarking code ignored parameter constraints; this passes them on to the optimization.

Reviewed By: sdaulton

Differential Revision: D13557896

fbshipit-source-id: a61e1abc896f26429cf71971adf8e4dd21da0990
Summary: Refit the model after exceptions to address cases where model fitting succeeds, but an exception is raised at prediction time. This is not as clean as I would like, but it keeps any observed data, so the benchmarks reflect our ability to retry after exceptions without re-evaluating in a real optimization scenario.

Reviewed By: Balandat

Differential Revision: D13560964

fbshipit-source-id: 7014e276c9081649fde6befb114f84d88db39735
…se (#3)

Summary:
Pull Request resolved: #3

Add ability to supply an arbitrary model to the benchmarks (so long as it has a likelihood so that we know how to refit it). Also, add support for observation noise.

Reviewed By: Balandat

Differential Revision: D13561279

fbshipit-source-id: 94bcf92819e9b6d0977b8a1f4ee43f6bce0c3e80
Summary:
This adjusts KG to use the fantasies provided by GPyTorch and
adds an option that utilizes the posterior mean instead of sampling for the
inner average.  This is appropriate when the objective is linear and there are no
constraints.

Reviewed By: Balandat

Differential Revision: D13595651

fbshipit-source-id: 24fcafca77ff45561297f51a12b46a477251d4cd
Summary:
Most of this is actually just copying Sobol over. This is meant to be temporary; the goal is to use the torch native version from pytorch/pytorch#10505 once the performance issues have been resolved.

The new content is doing qMC for standard Normal RVs, using the Box-Muller transform.

Reviewed By: bletham

Differential Revision: D13591551

fbshipit-source-id: cd57745ef837b1a88054bb6a55c760b553c144fa
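Rough sketch of the Box-Muller idea on top of a Sobol sequence (shown here with torch.quasirandom.SobolEngine purely for illustration; the names and sizes are not from the diff):

```python
import math
import torch
from torch.quasirandom import SobolEngine

sobol = SobolEngine(dimension=2, scramble=True, seed=0)
u = sobol.draw(1024)                      # uniform qMC points in [0, 1)^2
u1 = u[:, 0].clamp_min(1e-10)             # guard against log(0)
u2 = u[:, 1]
r = torch.sqrt(-2.0 * torch.log(u1))
z = torch.stack([r * torch.cos(2 * math.pi * u2),
                 r * torch.sin(2 * math.pi * u2)], dim=-1)
# z is (approximately) a qMC sample from the 2-d standard normal
```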
Summary:
Simplify installation (and building of cython).

For now this requires both `numpy` and `cython` to be available at setup time. We could also ship generated `.c` files in the future (see e.g.
https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#distributing-cython-modules), but that is low-pri given that we want to move to the torch native Sobol engine from pytorch/pytorch#10505 (if that finally makes its way in...)

Reviewed By: bkarrer

Differential Revision: D13592395

fbshipit-source-id: 4bd2c7fada88dfad2bffce7c8b31dae3fca47e1f
…te generation

Summary:
This diff introduces 2 major changes:

1. **Adds a new heuristic for generating a set of (q-batch) initial conditions for candidate generation.**
This heuristic performs batch evaluation of the acquisition function to compute function values for all q-batches in supplied samples `X` (ideally qMC samples). It makes it possible to utilize the `covar_module` (more generally, any similarity measure) of a model to compute covariances between points, from which it derives a permutation-invariant metric used to encourage "diversity" between points. At the end of the day, it samples from the input candidates according to weights that represent an "exploitation/exploration" tradeoff, controlled via the `eta_Y` and `eta_sim` options (see docstring; a simplified sketch of the weighting follows this commit message).
Computing the permutation-invariant metric can be costly for large `q`, so this can be changed to an approximate permutation-invariant metric by limiting the comparison to a random subset of all possible permutations. Setting the `eta_sim` option to zero ignores similarity in feature space entirely, in which case most of the computation is avoided and the heuristic is fast.

2. **This diff also changes the way random restart optimization is done**
It splits optimization from initialization, so that the `b x q x d`-batched set of initial conditions is passed in to `gen_candidates_scipy`. There is a lightweight helper function `get_best_candidates` that extracts the best of the optimized candidates. This is done in order to avoid having "random restart inflation", where we define a new optimization method for each initialization method.

These changes required a number of other changes throughout.

Reviewed By: bletham

Differential Revision: D13552294

fbshipit-source-id: cd1736df4d78e88ed2e28d9101ecc94b40498d83
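A very simplified sketch of the exploitation/exploration weighting described in point 1, ignoring the similarity term (the function name and the single `eta` parameter are illustrative, not the actual interface):

```python
import torch

def sample_initial_conditions(X, Y, n, eta=1.0):
    """X: b x q x d candidate q-batches, Y: b acquisition values, n: number to keep."""
    weights = torch.exp(eta * (Y - Y.mean()) / Y.std())   # favor high acquisition values
    idx = torch.multinomial(weights, n, replacement=False)
    return X[idx]                                          # n x q x d initial conditions
```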
Summary: see title

Reviewed By: bkarrer

Differential Revision: D13636398

fbshipit-source-id: a2bc70e098cea239c02598aece7d69eaa1c4dd47
Summary: `torch.nn.module.Module.load_state_dict` does not copy over tensor attributes, just the data, so the tensors in the new model (Module) must be set before calling `load_state_dict`.

Reviewed By: Balandat

Differential Revision: D13641897

fbshipit-source-id: b3eb43de20df95acc1015f0e545c75c4d1cb5063
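Hedged sketch of the resulting pattern (model class and data are illustrative): construct the new model with its tensors (training data) already set, then copy the hyperparameter data over.

```python
import torch
from botorch.models import SingleTaskGP

old_X = torch.rand(10, 2, dtype=torch.double)
old_Y = old_X.sum(dim=-1, keepdim=True)
old_model = SingleTaskGP(old_X, old_Y)
# ... fit old_model ...

new_X = torch.rand(15, 2, dtype=torch.double)
new_Y = new_X.sum(dim=-1, keepdim=True)
new_model = SingleTaskGP(new_X, new_Y)              # tensors set first
new_model.load_state_dict(old_model.state_dict())   # then copy hyperparameter data
```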
Summary: see title

Reviewed By: Balandat

Differential Revision: D13653696

fbshipit-source-id: 295f81fb2b89e31795b34a74c9c0e67a302bfb6b
Summary:
- Use same initial model, initial training data, and initial conditions (via fixed seed)
- support acquisition_function_args
- Log runtime for each iteration

Reviewed By: Balandat

Differential Revision: D13672702

fbshipit-source-id: 0deea430683af4fb8f5d1241e7148f410057b865
Summary:
It's possible that the joint optimization is difficult and that a
sequential greedy optimization would work better for maximizing acquisition
functions.  This adds support for the sequential greedy optimization for
q-batch acquisition functions by leveraging the support for pending
observations.

It also adds a fine-tuning alternative that utilizes the sequential greedy optimization
as an initialization to the joint optimization.

We can potentially improve this later by special-casing the behavior for q=1
batch acquisition functions when closed forms are available.

Reviewed By: Balandat

Differential Revision: D13519985

fbshipit-source-id: 8a135971dd9365e32a1a5118216bbd88e3641d40
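Conceptual sketch of the sequential greedy loop described above (not the exact helper added here; `optimize_one` is a hypothetical single-point optimizer):

```python
import torch

def sequential_greedy(acqf, bounds, q, optimize_one):
    """Pick a q-batch one point at a time, marking earlier picks as pending."""
    selected = []
    for _ in range(q):
        acqf.set_X_pending(torch.cat(selected, dim=0) if selected else None)
        x_i = optimize_one(acqf, bounds)   # hypothetical: returns a 1 x d candidate
        selected.append(x_i)
    acqf.set_X_pending(None)               # reset pending points
    return torch.cat(selected, dim=0)      # q x d batch
```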
Summary:
The way the Sobol dimension was fixed in the previous implementation was wrong.
This corrects it using the algorithm from pg. 123 of:
G. Pages. Numerical Probability: An Introduction with Applications to Finance. Universitext. Springer International Publishing, 2018.

Reviewed By: bkarrer

Differential Revision: D13718713

fbshipit-source-id: cf646e85f8685ac29ce28556fc396326a416bf9a
Summary: See title

Reviewed By: bkarrer

Differential Revision: D13720137

fbshipit-source-id: 9af3e2eceb10fb9fd07285f881924b3155ae30e0
Summary: See title and discussion in T39305088

Reviewed By: bkarrer

Differential Revision: D13734605

fbshipit-source-id: 5c0564d2239ec808aa6e0fb97079ed53feef3597
Summary: Removes need to hardcode max dim in other parts of the code

Reviewed By: bkarrer

Differential Revision: D13735581

fbshipit-source-id: 2cac5b1c1a2b940c0f89ef3e3f26a51cc8bab981
Summary:
This alters base_samples to use QMC normal samples instead of normal
samples when the dimensionality is low enough.

Reviewed By: Balandat

Differential Revision: D13731853

fbshipit-source-id: 51f619f027b18cbd5e661b8253ac5e78e9018277
Summary:
Autograd did not get along with the `fill_` call in `fix_features`.
Also, if we don't modify the tensor in-place, we have to call `fix_features` once more in `gen_candidates`.

Reviewed By: sdaulton

Differential Revision: D13792864

fbshipit-source-id: c60c088080bcb62b30add444825ed59c7c80d354
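Hedged illustration of the underlying issue: building the full input out-of-place (by concatenation) instead of using an in-place `fill_` keeps autograd happy, with gradients flowing only to the free columns.

```python
import torch

free_X = torch.rand(5, 2, requires_grad=True)    # the columns being optimized
fixed_col = torch.full((5, 1), 0.7)              # a feature held fixed at 0.7
# out-of-place assembly: column order here is [free, fixed, free]
X_full = torch.cat([free_X[:, :1], fixed_col, free_X[:, 1:]], dim=-1)
X_full.sum().backward()                          # gradients reach free_X only
```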
Summary: This adds unit tests for fixed features for scipy candidate generation, and fixes a bug where the gradient was computed improperly when using fixed features.

Reviewed By: Balandat

Differential Revision: D13517500

fbshipit-source-id: b6587e6caf34c626d2973594e817c1a7c70092fc
Summary:
Includes the following changes:
- the `likelihood` argument for `SingleTaskGP` is now optional - if not provided, use a reasonable default (`GaussianLikelihood` with reasonable prior)
- the `reinitialize` function has gained a new `keep_params` argument. If `True`
  (default), reinitializing does not reset the model parameters (this will be particularly
  useful for speeding up refitting models in closed-loop optimization)
- the transforms for the noise are modified to provide a lower bound on the
  noise level, which should help significantly with numerical stability

Reviewed By: sdaulton

Differential Revision: D13815248

fbshipit-source-id: 5a5d7f2a8c270045597c947f73fa7fa0e9df4ad9
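Sketch of the kind of default this describes (the specific prior and bound values are assumptions, not necessarily the ones used in the diff):

```python
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.priors import GammaPrior

likelihood = GaussianLikelihood(
    noise_prior=GammaPrior(1.1, 0.05),
    noise_constraint=GreaterThan(1e-4),  # lower bound on the noise level
)
```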
Summary:
In addition to the user_supplied bounds dict for fit_model, the optimization now checks for a `parameter_bounds` dictionary on the model of the same type.
This allows specifying the bounds with the model instead of manually passing them in everywhere. Ideally, this functionality would be added upstream
in gpytorch so that bounds can be passed in upon construction (which would also allow defining the bounds on the module to which the parameter belongs,
and hence enable traversing the module tree recursively like we do with parameters etc.)

I will add reasonable default parameter constraints to the standard botorch models in a separate diff.

Reviewed By: sdaulton

Differential Revision: D13823751

fbshipit-source-id: 2fc494ce95cc817674cb577ddb56fc313cb337e2
Summary: This fixes an issue where gpytorch.Priors are not properly deepcopied.

Reviewed By: Balandat

Differential Revision: D13821358

fbshipit-source-id: 1280d454d0b634eabec19956527bfd5ff7f67969
eytan and others added 18 commits April 17, 2019 23:46
Summary:
Revised documentation / narrative on ax integration.
Pull Request resolved: #71

Reviewed By: Balandat

Differential Revision: D14994033

Pulled By: eytan

fbshipit-source-id: 45dd4a0ce44c3bd5870e4f74fec8affe90ae1b53
Summary: -

Reviewed By: danielrjiang

Differential Revision: D14994297

fbshipit-source-id: fdccbd9dd7010111c1da4c2db5c26ca338a4143f
Summary: --

Reviewed By: danielrjiang

Differential Revision: D14994298

fbshipit-source-id: 9a51eb7a718417a196160900c4a55f18be4c21f3
Summary: see title

Reviewed By: danielrjiang

Differential Revision: D15003889

fbshipit-source-id: 8e222dbca6b6fe35ca54896cb82ec2589ad64f67
Summary: see title

Reviewed By: danielrjiang

Differential Revision: D15004804

fbshipit-source-id: 382fa90d9f7041eecbb910277a179601663df636
Summary:
- Use batched multioutput gp when possible: single task and either (a) single output or (b) same training data for all outputs
- remove `reinitialize` method, we can just use `load_state_dict`

Reviewed By: Balandat

Differential Revision: D15009564

fbshipit-source-id: 43cff9cc8ca2de081a8546e2f09d7282d59a9f7e
Summary:
Clone of D15024829. Fixes issue of the notebook not being downloadable.

Pull Request resolved: #72

Reviewed By: kkashin

Differential Revision: D15025428

Pulled By: Balandat

fbshipit-source-id: 297ebb45c3d3ff4476cf7654b00e9d796f68ab53
Summary: We're using BoTorch going forward.

Reviewed By: danielrjiang

Differential Revision: D15028048

fbshipit-source-id: 8a85db26f83578c5d2efa1f3ca04a440a617b0bb
Summary:
Pull Request resolved: #73

Doesn't make a lot of sense to have this blog. We can use the
AI/PyTorch blogs and other blogs instead.

Reviewed By: kkashin

Differential Revision: D15028058

fbshipit-source-id: 60652cf7c0b73dae2238595182c9647f9e30be77
Summary:
Makes capitalization consistent. Avoids code-formatting with backticks in
headings.

Reviewed By: danielrjiang

Differential Revision: D15028343

fbshipit-source-id: fce40fce4362295b215af0e383ec801e0c93a18e
Summary: Breaks flake8 tests in travis

Reviewed By: danielrjiang

Differential Revision: D15028064

fbshipit-source-id: 5627a054a55dc30b389cd01d360bc4132a0ee3c1
Summary:
Pull Request resolved: #74

Let's test against this commit that will likely become 0.3.2.
Also adds a "test" installation so that we don't install all the notebook stuff when testing in contbuild.

Reviewed By: danielrjiang

Differential Revision: D15037922

fbshipit-source-id: e6eea4920509cca931d194ce372a6ab53cd6d419
Summary: --

Reviewed By: Balandat

Differential Revision: D15035238

fbshipit-source-id: 07d30e03972c58a97fc4975a8a7581e4e67efca9
Summary:
Moves `initialize_q_batch` to botorch_fb as
`initialize_q_batch_complex` since it's not used
currently. Renames `initialize_q_batch_simple` to `initialize_q_batch` in
botorch. This mostly cleans things up to make them less confusing; will stack
another diff that improves heuristics and functionality.

Reviewed By: danielrjiang

Differential Revision: D15025582

fbshipit-source-id: eaab064462b8a52b1f4ff467b76d2839123007b9
…#75)

Summary:
Gets rid of the `OptimizeWarnings` issued in newer scipy versions if an unknown option is passed in.
Pull Request resolved: #75

Reviewed By: danielrjiang

Differential Revision: D15039022

Pulled By: Balandat

fbshipit-source-id: a1a74f8607b306a5f8018ea27968ff6f71a78997
Summary: now uses ModelListGP and `load_state_dict` instead of `reinitialize`

Reviewed By: Balandat

Differential Revision: D15039367

fbshipit-source-id: ba428c5325f5cfda74316ebc8a2d4fbe4f85fe09
Summary: --

Reviewed By: Balandat

Differential Revision: D15043384

fbshipit-source-id: 53365d388f369c07645c2eaa9a37dfa0f1783971
r"""
"""
if sampler is None:
    sampler = IIDNormalSampler(num_samples=16)

There is pretty much no reason to ever not use SobolQMCNormalSampler in place of IIDNormalSampler.


Actually we really just want to use SobolQMCNormalSampler if we are using the samples for Monte Carlo integration. If we are doing something different with the samples (like taking their max), then we should use the real distribution.
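In code, the distinction boils down to something like this (the import path and the `num_samples`-style constructor follow this PR's usage and may differ in other versions):

```python
from botorch.sampling import IIDNormalSampler, SobolQMCNormalSampler

mc_sampler = SobolQMCNormalSampler(num_samples=16)   # samples feed a Monte Carlo average
max_sampler = IIDNormalSampler(num_samples=16)       # samples used directly, e.g. for maxima
```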

self.sampler = sampler
self.register_buffer(
"candidate_set",
torch.rand(candidate_set_size, bounds.size(1), device=bounds.device, dtype=bounds.dtype)

You might also want to use Sobol sequences here as well. See the docs for SobolQMCNormalSampler for sample usage.


So SobolQMCNormalSampler is used for sampling from a posterior. Here the sampling is just from the feature space, so you should use SobolEngine from botorch.qmc.sobol instead. Note that I'm planning to land #55 soon, which will remove the botorch SobolEngine and use the one from torch.quasirandom instead (that is basically the same engine, just upstreamed to ATen).


Is there any reason not to use torch.quasirandom directly?


torch.quasirandom is just the module. It provides a SobolEngine that you can use to draw samples from the unit cube. Btw, I just merged #55.
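Hedged sketch of that suggestion: draw the candidate set from a (scrambled) Sobol sequence over the box given by `bounds`, instead of `torch.rand` (the bounds and set size below are illustrative):

```python
import torch
from torch.quasirandom import SobolEngine

bounds = torch.tensor([[0.0, 0.0], [1.0, 2.0]])            # 2 x d box, illustrative values
candidate_set_size = 512
engine = SobolEngine(dimension=bounds.size(1), scramble=True)
u = engine.draw(candidate_set_size).to(dtype=bounds.dtype, device=bounds.device)
candidate_set = bounds[0] + (bounds[1] - bounds[0]) * u    # scale unit-cube points to the bounds
```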

def candidate_set_argmax(self):
r"""
"""
if self._candidate_set_argmax is None:
@eytan eytan Apr 25, 2019

What is preventing you from initializing self._candidate_set_argmax to self.candidate_set_max_values() at the end of the constructor? Then you wouldn't need either of these methods. Perhaps I am missing something here?


I guess this leads to "lazy" evaluation where the compute is done upon the first evaluation of the acquisition function rather than upon its construction. I don't have strong feelings about where this should happen - would be interested in seeing the relative timings.


eytan commented Apr 25, 2019

OK, last comment: since MES is a non-myopic search strategy, does it make sense to also consider the value of the "best x" according to the model (e.g., what we would pick according to the posterior mean), rather than the best observed so far? I don't recall what quantity they plot in the MES paper, but IIRC in the KG papers they often plot the value of the best x according to the model, rather than the best tried so far.

@Balandat Balandat left a comment

Looks pretty good, some minor comments on the implementation to start with. Haven't had time to look through the notebook.

stdv = posterior.variance.sqrt()

normalized_mvs = (self.candidate_set_max_values() - mean) / stdv
normal = torch.distributions.Normal(torch.zeros_like(normalized_mvs), torch.ones_like(normalized_mvs))

Is it not possible to batch-evaluate the pdf of a single N(0, 1), rather than constructing a full batched Normal? Seems like a waste to do this.
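What the suggestion amounts to, roughly (stand-in tensor shapes):

```python
import torch

normalized_mvs = torch.randn(32, 4)          # stand-in for (max_values - mean) / stdv
normal = torch.distributions.Normal(0.0, 1.0)
pdf = normal.log_prob(normalized_mvs).exp()  # a single N(0, 1) broadcasts over the whole batch
cdf = normal.cdf(normalized_mvs)
```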



eytan commented Apr 25, 2019

I ran into an issue where the notebook craps out after about ~30-80 iterations if you set NUM_BATCHES to, say, 100, because initialize_q_batch() is receiving a Y tensor that contains a nan. I wonder if there is something that is causing cdf in MaxValueEntropySearch.forward() to be non-positive.

@Balandat

Seems like this could be caused by extreme values in the pdf/cdf evaluation. I’ll take a look.
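One possible guard (purely illustrative, not necessarily the fix that was adopted): clamp the cdf away from zero before taking the log, so extreme normalized values can't produce -inf/NaN downstream.

```python
import torch

normal = torch.distributions.Normal(0.0, 1.0)
normalized_mvs = torch.tensor([-40.0, 0.0, 3.0])    # an extreme value drives cdf to 0
cdf = normal.cdf(normalized_mvs).clamp_min(1e-10)   # keep strictly positive
log_cdf = torch.log(cdf)                            # finite instead of -inf/NaN
```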

@Balandat

@eytan regarding your running into dimensionality restrictions with Sobol: When computing the max values from the candidate set we do sample from a candidate_set_size-dimensional MVN, so if candidate_set_size is larger than 1111 the SobolSampler currently does not support this. I'm assuming sampling from the joint here is very much what we want to do.

@Balandat

I cleaned this up a little bit here

I changed the constructor to take a candidate set tensor rather than bounds; that way it's much easier to use different strategies for constructing that candidate set (maybe you have to adhere to parameter constraints etc.). We can make a simple helper function to generate one from bounds.

I also construct a single Normal and batch evaluate it rather than a multi-dimensional normal.

@gpleiss, @jacobrgardner, one point of feedback that would be great to get is regarding the q-batch transform - we introduced this to make the analytic functions more accessible, in the sense that people don't have to manually add a q-batch dimension to the input (they don't really need to know about that). The downside is that this changes the interface to the acquisition functions, so you have to be careful to write your optimization loop to account for that. I think right now I favor making the interface consistent and always requiring a q-batch when evaluating an acquisition function (and just erroring out if its size is not 1). Any thoughts on this?


eytan commented Apr 25, 2019
