Skip to content

Commit 58991c4

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove HeteroskedasticSingleTaskGP (#2616)
Summary: This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports: - #861 - #933 - #2551 Differential Revision: D65543676
1 parent 8b13120 commit 58991c4

File tree

8 files changed

+24
-251
lines changed

8 files changed

+24
-251
lines changed

botorch/acquisition/joint_entropy_search.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,10 @@
2929
from botorch import settings
3030
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
3131
from botorch.acquisition.objective import PosteriorTransform
32-
3332
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
34-
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
3533
from botorch.models.model import Model
36-
3734
from botorch.models.utils import check_no_nans, fantasize as fantasize_flag
35+
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
3836
from botorch.sampling.normal import SobolQMCNormalSampler
3937
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
4038
from torch import Tensor

botorch/models/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
1818
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
1919

20-
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP, SingleTaskGP
20+
from botorch.models.gp_regression import SingleTaskGP
2121
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2222
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
2323
from botorch.models.higher_order_gp import HigherOrderGP
@@ -33,7 +33,6 @@
3333
"SaasFullyBayesianSingleTaskGP",
3434
"SaasFullyBayesianMultiTaskGP",
3535
"GenericDeterministicModel",
36-
"HeteroskedasticSingleTaskGP",
3736
"HigherOrderGP",
3837
"KroneckerMultiTaskGP",
3938
"MixedSingleTaskGP",

botorch/models/converter.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from botorch.exceptions import UnsupportedError
1818
from botorch.exceptions.warnings import BotorchWarning
1919
from botorch.models import SingleTaskGP
20-
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
2120
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2221
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
2322
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
@@ -84,12 +83,6 @@ def _check_compatibility(models: ModuleList) -> None:
8483
"All models must be of type BatchedMultiOutputGPyTorchModel."
8584
)
8685

87-
# TODO: Add support for HeteroskedasticSingleTaskGP.
88-
if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
89-
raise NotImplementedError(
90-
"Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
91-
)
92-
9386
# TODO: Add support for custom likelihoods.
9487
if any(getattr(m, "_is_custom_likelihood", False) for m in models):
9588
raise NotImplementedError(
@@ -289,11 +282,6 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
289282
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
290283
was_training = batch_model.training
291284
batch_model.train()
292-
# TODO: Add support for HeteroskedasticSingleTaskGP.
293-
if isinstance(batch_model, HeteroskedasticSingleTaskGP):
294-
raise NotImplementedError(
295-
"Conversion of HeteroskedasticSingleTaskGP is currently not supported."
296-
)
297285
if isinstance(batch_model, MixedSingleTaskGP):
298286
raise NotImplementedError(
299287
"Conversion of MixedSingleTaskGP is currently not supported."
@@ -393,12 +381,7 @@ def batched_multi_output_to_single_output(
393381
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
394382
was_training = batch_mo_model.training
395383
batch_mo_model.train()
396-
# TODO: Add support for HeteroskedasticSingleTaskGP.
397-
if isinstance(batch_mo_model, HeteroskedasticSingleTaskGP):
398-
raise NotImplementedError(
399-
"Conversion of HeteroskedasticSingleTaskGP currently not supported."
400-
)
401-
elif not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel):
384+
if not isinstance(batch_mo_model, BatchedMultiOutputGPyTorchModel):
402385
raise UnsupportedError("Only BatchedMultiOutputGPyTorchModels are supported.")
403386
# TODO: Add support for custom likelihoods.
404387
elif getattr(batch_mo_model, "_is_custom_likelihood", False):

botorch/models/gp_regression.py

Lines changed: 15 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -10,58 +10,47 @@
1010
These models are often a good starting point and are further documented in the
1111
tutorials.
1212
13-
`SingleTaskGP` and `HeteroskedasticSingleTaskGP` are single-task exact GP models,
14-
differing in how they treat noise. They use relatively strong priors on the Kernel
15-
hyperparameters, which work best when covariates are normalized to the unit cube
16-
and outcomes are standardized (zero mean, unit variance). By default, these models
17-
use a `Standardize` outcome transform, which applies this standardization. However,
18-
they do not (yet) use an input transform by default.
19-
20-
These models all work in batch mode (each batch having its own hyperparameters).
21-
When the training observations include multiple outputs, these models use
13+
`SingleTaskGP` is a single-task exact GP model that uses relatively strong priors on
14+
the Kernel hyperparameters, which work best when covariates are normalized to the unit
15+
cube and outcomes are standardized (zero mean, unit variance). By default, this model
16+
uses a `Standardize` outcome transform, which applies this standardization. However,
17+
it does not (yet) use an input transform by default.
18+
19+
`SingleTaskGP` model works in batch mode (each batch having its own hyperparameters).
20+
When the training observations include multiple outputs, `SingleTaskGP` uses
2221
batching to model outputs independently.
2322
24-
These models all support multiple outputs. However, as single-task models,
25-
`SingleTaskGP` and `HeteroskedasticSingleTaskGP` should be used only when the
26-
outputs are independent and all use the same training data. If outputs are
27-
independent and outputs have different training data, use the `ModelListGP`.
28-
When modeling correlations between outputs, use a multi-task model like `MultiTaskGP`.
23+
`SingleTaskGP` supports multiple outputs. However, as a single-task model,
24+
`SingleTaskGP` should be used only when the outputs are independent and all
25+
use the same training inputs. If outputs are independent but they have different
26+
training inputs, use the `ModelListGP`. When modeling correlations between outputs,
27+
use a multi-task model like `MultiTaskGP`.
2928
"""
3029

3130
from __future__ import annotations
3231

3332
import warnings
34-
from typing import NoReturn
3533

3634
import torch
3735
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
3836
from botorch.models.model import FantasizeMixin
3937
from botorch.models.transforms.input import InputTransform
40-
from botorch.models.transforms.outcome import Log, OutcomeTransform, Standardize
38+
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
4139
from botorch.models.utils import validate_input_scaling
4240
from botorch.models.utils.gpytorch_modules import (
4341
get_covar_module_with_dim_scaled_prior,
4442
get_gaussian_likelihood_with_lognormal_prior,
45-
MIN_INFERRED_NOISE_LEVEL,
4643
)
4744
from botorch.utils.containers import BotorchContainer
4845
from botorch.utils.datasets import SupervisedDataset
4946
from botorch.utils.types import _DefaultType, DEFAULT
50-
from gpytorch.constraints.constraints import GreaterThan
5147
from gpytorch.distributions.multivariate_normal import MultivariateNormal
52-
from gpytorch.likelihoods.gaussian_likelihood import (
53-
_GaussianLikelihoodBase,
54-
FixedNoiseGaussianLikelihood,
55-
GaussianLikelihood,
56-
)
48+
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
5749
from gpytorch.likelihoods.likelihood import Likelihood
58-
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise
5950
from gpytorch.means.constant_mean import ConstantMean
6051
from gpytorch.means.mean import Mean
61-
from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
6252
from gpytorch.models.exact_gp import ExactGP
6353
from gpytorch.module import Module
64-
from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
6554
from torch import Tensor
6655

6756

@@ -253,107 +242,3 @@ def forward(self, x: Tensor) -> MultivariateNormal:
253242
mean_x = self.mean_module(x)
254243
covar_x = self.covar_module(x)
255244
return MultivariateNormal(mean_x, covar_x)
256-
257-
258-
class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
259-
r"""A single-task exact GP model using a heteroskedastic noise model.
260-
261-
This model differs from `SingleTaskGP` with observed observation noise
262-
variances (`train_Yvar`) in that it can predict noise levels out of sample.
263-
This is achieved by internally wrapping another GP (a `SingleTaskGP`) to model
264-
the (log of) the observation noise. Noise levels must be provided to
265-
`HeteroskedasticSingleTaskGP` as `train_Yvar`.
266-
267-
Examples of cases in which noise levels are known include online
268-
experimentation and simulation optimization.
269-
270-
Example:
271-
>>> train_X = torch.rand(20, 2)
272-
>>> train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
273-
>>> se = torch.linalg.norm(train_X, dim=1, keepdim=True)
274-
>>> train_Yvar = 0.1 + se * torch.rand_like(train_Y)
275-
>>> model = HeteroskedasticSingleTaskGP(train_X, train_Y, train_Yvar)
276-
"""
277-
278-
def __init__(
279-
self,
280-
train_X: Tensor,
281-
train_Y: Tensor,
282-
train_Yvar: Tensor,
283-
outcome_transform: OutcomeTransform | None = None,
284-
input_transform: InputTransform | None = None,
285-
) -> None:
286-
r"""
287-
Args:
288-
train_X: A `batch_shape x n x d` tensor of training features.
289-
train_Y: A `batch_shape x n x m` tensor of training observations.
290-
train_Yvar: A `batch_shape x n x m` tensor of observed measurement
291-
noise.
292-
outcome_transform: An outcome transform that is applied to the
293-
training data during instantiation and to the posterior during
294-
inference (that is, the `Posterior` obtained by calling
295-
`.posterior` on the model will be on the original scale).
296-
Note that the noise model internally log-transforms the
297-
variances, which will happen after this transform is applied.
298-
input_transform: An input transfrom that is applied in the model's
299-
forward pass.
300-
"""
301-
if outcome_transform is not None:
302-
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
303-
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
304-
validate_input_scaling(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
305-
self._set_dimensions(train_X=train_X, train_Y=train_Y)
306-
noise_likelihood = GaussianLikelihood(
307-
noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
308-
batch_shape=self._aug_batch_shape,
309-
noise_constraint=GreaterThan(
310-
MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0
311-
),
312-
)
313-
# Likelihood will always get evaluated with transformed X, so we need to
314-
# transform the training data before constructing the noise model.
315-
with torch.no_grad():
316-
transformed_X = self.transform_inputs(
317-
X=train_X, input_transform=input_transform
318-
)
319-
noise_model = SingleTaskGP(
320-
train_X=transformed_X,
321-
train_Y=train_Yvar,
322-
likelihood=noise_likelihood,
323-
outcome_transform=Log(),
324-
)
325-
likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
326-
# This is hacky -- this class used to inherit from SingleTaskGP, but it
327-
# shouldn't so this is a quick fix to enable getting rid of that
328-
# inheritance
329-
SingleTaskGP.__init__(
330-
# pyre-fixme[6]: Incompatible parameter type
331-
self,
332-
train_X=train_X,
333-
train_Y=train_Y,
334-
likelihood=likelihood,
335-
outcome_transform=None,
336-
input_transform=input_transform,
337-
)
338-
self.register_added_loss_term("noise_added_loss")
339-
self.update_added_loss_term(
340-
"noise_added_loss", NoiseModelAddedLossTerm(noise_model)
341-
)
342-
if outcome_transform is not None:
343-
self.outcome_transform = outcome_transform
344-
self.to(train_X)
345-
346-
# pyre-fixme[15]: Inconsistent override
347-
def condition_on_observations(self, *_, **__) -> NoReturn:
348-
raise NotImplementedError
349-
350-
# pyre-fixme[15]: Inconsistent override
351-
def subset_output(self, idcs) -> NoReturn:
352-
raise NotImplementedError
353-
354-
def forward(self, x: Tensor) -> MultivariateNormal:
355-
if self.training:
356-
x = self.transform_inputs(x)
357-
mean_x = self.mean_module(x)
358-
covar_x = self.covar_module(x)
359-
return MultivariateNormal(mean_x, covar_x)

botorch_community/acquisition/scorebo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
)
3232
from botorch.acquisition.objective import ScalarizedPosteriorTransform
3333
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
34-
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
3534
from botorch.models.utils import fantasize as fantasize_flag
35+
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
3636
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
3737
from botorch_community.acquisition.bayesian_active_learning import DISTANCE_METRICS
3838
from torch import Tensor

docs/models.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ Noise can be treated in several different ways:
7272
if you know your observations are noiseless (by passing a zero noise level).
7373

7474
- _Heteroskedastic_: Noise is provided as an input and is modeled to allow for
75-
predicting noise out-of-sample. Models like `HeteroskedasticSingleTaskGP` take
76-
this approach.
75+
predicting noise out-of-sample. BoTorch does not implement a model that
76+
supports this out of the box.
7777

7878
## Standard BoTorch Models
7979

@@ -90,10 +90,6 @@ instead.
9090
- [`SingleTaskGP`](../api/models.html#botorch.models.gp_regression.SingleTaskGP):
9191
a single-task exact GP that supports both inferred and observed noise. When
9292
noise observations are not provided, it infers a homoskedastic noise level.
93-
- [`HeteroskedasticSingleTaskGP`](../api/models.html#botorch.models.gp_regression.HeteroskedasticSingleTaskGP):
94-
a single-task exact GP that differs from `SingleTaskGP` with observed noise in
95-
that it models heteroskedastic noise using an additional internal GP model. It
96-
requires noise observations.
9793
- [`MixedSingleTaskGP`](../api/models.html#botorch.models.gp_regression_mixed.MixedSingleTaskGP):
9894
a single-task exact GP that supports mixed search spaces, which combine
9995
discrete and continuous features.

test/models/test_converter.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@
77

88
import torch
99
from botorch.exceptions import UnsupportedError
10-
from botorch.models import (
11-
HeteroskedasticSingleTaskGP,
12-
ModelListGP,
13-
SingleTaskGP,
14-
SingleTaskMultiFidelityGP,
15-
)
10+
from botorch.models import ModelListGP, SingleTaskGP, SingleTaskMultiFidelityGP
1611
from botorch.models.converter import (
1712
_batched_kernel,
1813
batched_multi_output_to_single_output,
@@ -58,12 +53,6 @@ def test_batched_to_model_list(self):
5853
)
5954
list_gp = batched_to_model_list(batch_gp)
6055
self.assertIsInstance(list_gp, ModelListGP)
61-
# test HeteroskedasticSingleTaskGP
62-
batch_gp = HeteroskedasticSingleTaskGP(
63-
train_X, train_Y, torch.rand_like(train_Y)
64-
)
65-
with self.assertRaises(NotImplementedError):
66-
batched_to_model_list(batch_gp)
6756
# test with transforms
6857
input_tf = Normalize(
6958
d=2,
@@ -161,12 +150,6 @@ def test_model_list_to_batched(self):
161150
)
162151
with self.assertRaises(UnsupportedError):
163152
model_list_to_batched(ModelListGP(gp1, gp2))
164-
# test HeteroskedasticSingleTaskGP
165-
gp2 = HeteroskedasticSingleTaskGP(
166-
train_X, train_Y1, torch.ones_like(train_Y1)
167-
)
168-
with self.assertRaises(NotImplementedError):
169-
model_list_to_batched(ModelListGP(gp2))
170153
# test custom likelihood
171154
gp2 = SingleTaskGP(
172155
train_X,
@@ -419,11 +402,6 @@ def test_batched_multi_output_to_single_output(self):
419402
non_batch_model = SimpleGPyTorchModel(train_X, train_Y[:, :1])
420403
with self.assertRaises(UnsupportedError):
421404
batched_multi_output_to_single_output(non_batch_model)
422-
gp2 = HeteroskedasticSingleTaskGP(
423-
train_X, train_Y, torch.ones_like(train_Y)
424-
)
425-
with self.assertRaises(NotImplementedError):
426-
batched_multi_output_to_single_output(gp2)
427405
# test custom likelihood
428406
gp2 = SingleTaskGP(train_X, train_Y, likelihood=GaussianLikelihood())
429407
with self.assertRaises(NotImplementedError):

0 commit comments

Comments
 (0)