Apply input transforms when computing MLL in model closures #2527

Closed
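Background on the fix, as a minimal sketch (the toy data and the Normalize setup are illustrative assumptions; SingleTaskGP, Normalize, and Model.transform_inputs are real BoTorch APIs): in training mode the model's forward pass applies the input transform, but model.train_inputs stays untransformed, so the MLL closures previously passed raw inputs to the likelihood.

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize

# Inputs outside the unit cube so that Normalize actually changes them.
train_X = torch.linspace(0, 5, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(train_X)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=1))

model.train()
# forward() transforms the inputs internally, but train_inputs stays raw:
raw_X = model.train_inputs[0]
transformed_X = model.transform_inputs(X=raw_X)
# The two differ, which is why the closures below must transform explicitly.
assert not torch.allclose(raw_X, transformed_X)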
9 changes: 7 additions & 2 deletions botorch/models/gp_regression.py
@@ -345,12 +345,17 @@ def __init__(
                 MIN_INFERRED_NOISE_LEVEL, transform=None, initial_value=1.0
             ),
         )
+        # Likelihood will always get evaluated with transformed X, so we need to
+        # transform the training data before constructing the noise model.
+        with torch.no_grad():
+            transformed_X = self.transform_inputs(
+                X=train_X, input_transform=input_transform
+            )
         noise_model = SingleTaskGP(
-            train_X=train_X,
+            train_X=transformed_X,
             train_Y=train_Yvar,
             likelihood=noise_likelihood,
             outcome_transform=Log(),
-            input_transform=input_transform,
         )
         likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
         # This is hacky -- this class used to inherit from SingleTaskGP, but it
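A hedged sketch of the invariant the change above establishes (toy data; the dedicated noise likelihood and priors from the real constructor are omitted for brevity): the noise model now lives entirely in the transformed input space, so it is trained on transformed X and no longer carries an input_transform of its own.

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from botorch.models.transforms.outcome import Log

train_X = 5.0 * torch.rand(10, 2, dtype=torch.double)
train_Yvar = 0.01 + 0.1 * torch.rand(10, 1, dtype=torch.double)  # observed noise
input_transform = Normalize(d=2)
with torch.no_grad():
    transformed_X = input_transform(train_X)  # what the likelihood will see
# As in the diff: the noise model is built directly on transformed data and
# gets no transform of its own.
noise_model = SingleTaskGP(
    train_X=transformed_X,
    train_Y=train_Yvar,
    outcome_transform=Log(),
)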
27 changes: 21 additions & 6 deletions botorch/optim/closures/model_closures.py
@@ -9,7 +9,6 @@
 from __future__ import annotations

 from collections.abc import Sequence
-
 from itertools import chain, repeat
 from types import NoneType
 from typing import Any, Callable, Optional
@@ -174,9 +173,17 @@ def _get_loss_closure_exact_internal(
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""

def closure(**kwargs: Any) -> Tensor:
model_output = mll.model(*mll.model.train_inputs)
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs
model_output,
model.train_targets,
# During model training, the model inputs get transformed in the forward
# pass. The train_inputs property is not transformed yet, so we need to
# transform it before passing it to the likelihood for consistency.
*(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
**kwargs,
)
return -log_likelihood

Expand All @@ -190,11 +197,19 @@ def _get_loss_closure_sum_internal(
r"""SumMarginalLogLikelihood loss closure with internally managed data."""

def closure(**kwargs: Any) -> Tensor:
model_output = mll.model(*mll.model.train_inputs)
model = mll.model
# The inputs will get transformed in forward here.
model_output = model(*model.train_inputs)
log_likelihood = mll(
model_output,
mll.model.train_targets,
*map(list, mll.model.train_inputs),
model.train_targets,
# During model training, the model inputs get transformed in the forward
# pass. The train_inputs property is not transformed yet, so we need to
# transform it before passing it to the likelihood for consistency.
*(
(model.transform_inputs(X=t_in) for t_in in sub_t_in)
for sub_t_in in model.train_inputs
),
**kwargs,
)
return -log_likelihood
Expand Down
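A usage sketch of the patched closures (the model setup mirrors the tests below; get_loss_closure_with_grads is the BoTorch helper they exercise, and the data values are assumptions):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.optim.closures import get_loss_closure_with_grads
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.linspace(0, 5, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(train_X)
model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll.train()
params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
closure = get_loss_closure_with_grads(mll, params)
# The likelihood inside the closure now receives transformed train inputs.
loss, grads = closure()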
120 changes: 94 additions & 26 deletions test/optim/closures/test_model_closures.py
@@ -17,35 +17,68 @@
 )
 from botorch.utils.testing import BotorchTestCase
 from gpytorch import settings as gpytorch_settings
+from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
 from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood
+from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
+from gpytorch.module import Module
+from torch import Tensor
 from torch.utils.data import DataLoader, TensorDataset


+# Mock wrapping the __call__ directly is leading to errors like
+# TypeError: super(type, obj): obj must be an instance or subtype of type
+# so, doing this manually here.
+class WrapperLikelihood(GaussianLikelihood):
+    def __init__(self, base_likelihood: GaussianLikelihood):
+        """A wrapper around a GaussianLikelihood that stores the call args."""
+        Module.__init__(self)
+        self.base_likelihood = base_likelihood
+        self.call_args = []
+
+    def __call__(self, *args, **kwargs):
+        # Store the train inputs arg for testing.
+        self.call_args.append(args[1])
+        return self.base_likelihood(*args, **kwargs)
+
+
+def _get_mlls(
+    device: torch.device, wrap_likelihood: bool = False
+) -> tuple[Tensor, list[MarginalLogLikelihood]]:
+    """Returns the train X, along with two MLLs: one for a SingleTaskGP and
+    one for a ModelListGP.
+
+    Args:
+        device: The device to use.
+        wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
+            This is useful for comparing call args later.
+    """
+    with torch.random.fork_rng():
+        torch.manual_seed(0)
+        # Inputs are not in the unit cube to ensure input transform is applied.
+        train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
+        train_Y = torch.sin((2 * pi) * train_X)
+        train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
+    mlls = []
+    model = SingleTaskGP(
+        train_X=train_X,
+        train_Y=train_Y,
+        input_transform=Normalize(d=1),
+        outcome_transform=Standardize(m=1),
+    )
+    if wrap_likelihood:
+        model.likelihood = WrapperLikelihood(model.likelihood)
+    mll = ExactMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+
+    model = ModelListGP(model, model)
+    mll = SumMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+    return train_X.to(device=device, dtype=torch.double), mlls
+
+
 class TestLossClosures(BotorchTestCase):
-    def setUp(self):
-        super().setUp()
-        with torch.random.fork_rng():
-            torch.manual_seed(0)
-            train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
-            train_Y = torch.sin((2 * pi) * train_X)
-            train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-
-        self.mlls = {}
-        model = SingleTaskGP(
-            train_X=train_X,
-            train_Y=train_Y,
-            input_transform=Normalize(d=1),
-            outcome_transform=Standardize(m=1),
-        )
-        mll = ExactMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-        model = ModelListGP(model, model)
-        mll = SumMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-    def test_main(self):
-        for mll in self.mlls.values():
+    def test_main(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             out = mll.model(*mll.model.train_inputs)
             loss = -mll(out, mll.model.train_targets).sum()
             loss.backward()
@@ -63,8 +96,8 @@ def test_main(self):
             self.assertTrue(loss.equal(_loss))
             self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads)))

-    def test_data_loader(self):
-        for mll in self.mlls.values():
+    def test_data_loader(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             if type(mll) is not ExactMarginalLogLikelihood:
                 continue

@@ -86,3 +119,38 @@ def test_data_loader(self):
             closure = get_loss_closure_with_grads(mll, params, data_loader=loader)
             with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"):
                 closure()
+
+    def test_with_input_transforms(self) -> None:
+        # This test reproduces the bug reported in issue #2515.
+        train_X, mlls = _get_mlls(device=self.device, wrap_likelihood=True)
+        for mll in mlls:
+            if isinstance(mll, SumMarginalLogLikelihood):
+                # The likelihood is called twice here since it is the same
+                # likelihood in both child models.
+                likelihood = mll.model.models[0].likelihood
+                expected_calls1 = 2  # In the closure call.
+                expected_calls2 = 6  # Closure + posterior calls.
+            else:
+                likelihood = mll.model.likelihood
+                expected_calls1 = 1  # In the closure call.
+                expected_calls2 = 4  # Closure + posterior calls.
+            likelihood.call_args = []  # reset since it is shared between the models.
+            params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
+            # Evaluate the closure to mimic the model fitting process.
+            mll.train()
+            closure = get_loss_closure_with_grads(mll, params)
+            closure()
+            self.assertEqual(len(likelihood.call_args), expected_calls1)
+            # Call the model posterior to reproduce post-fitting usage.
+            mll.model.posterior(train_X, observation_noise=True)
+            # Compare the call args to ensure they're all the same.
+            # Likelihood is called twice on model(X) and once for adding the noise.
+            self.assertEqual(len(likelihood.call_args), expected_calls2)
+            arg0 = likelihood.call_args[0]
+            for i in range(1, expected_calls2):
+                argi = likelihood.call_args[i]
+                # The arg may be a tensor or a single element list of the tensor.
+                self.assertAllClose(
+                    arg0 if isinstance(arg0, Tensor) else arg0[0],
+                    argi if isinstance(argi, Tensor) else argi[0],
+                )
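A post-fitting sketch mirroring the test above (fit_gpytorch_mll and the toy data are illustrative assumptions): posterior calls with observation_noise=True route inputs through the same likelihood, which now receives consistently transformed inputs both during and after fitting.

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.linspace(0, 5, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(train_X)
model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # fitting goes through the patched loss closures
# The likelihood sees the same transformed inputs here as in the closure.
posterior = model.posterior(train_X, observation_noise=True)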