
Commit 83fc1c2

saitcakmak authored and facebook-github-bot committed
Apply input transforms when computing MLL in model closures
Summary: During model training, the input transforms are applied in `model.forward`. While evaluating the model closures, we pass the train inputs to the `mll`, which passes them down to the `likelihood`. If we don't transform the inputs before passing them to the `mll`, we end up evaluating `model.forward` and the `likelihood` with different inputs. This is not an issue during `posterior` evaluation, since the transforms are applied in `model.posterior` before the inputs are passed to `model.__call__` and the `likelihood`. This diff updates the model closures to transform the inputs before passing them to the `mll`.

Fixes #2515

Differential Revision: D62497392
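To make the inconsistency concrete, here is a minimal sketch of the fixed closure pattern (the toy model and data are made up for illustration; the real closures live in `botorch/optim/closures/model_closures.py`):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data outside the unit cube, so Normalize actually changes the inputs.
train_X = torch.linspace(0, 5, 10, dtype=torch.double).unsqueeze(-1)
train_Y = torch.sin(train_X)

model = SingleTaskGP(train_X=train_X, train_Y=train_Y, input_transform=Normalize(d=1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll.train()

# In train mode, `model.forward` applies Normalize to its inputs internally ...
model_output = model(*model.train_inputs)
# ... so the inputs forwarded to the likelihood through `mll` must be
# transformed as well; otherwise the two see different inputs.
transformed = (model.transform_inputs(X=t_in) for t_in in model.train_inputs)
loss = -mll(model_output, model.train_targets, *transformed)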
1 parent 4d49bf7 commit 83fc1c2

2 files changed: +114 additions, -32 deletions

botorch/optim/closures/model_closures.py (21 additions, 6 deletions)

@@ -9,7 +9,6 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-
 from itertools import chain, repeat
 from types import NoneType
 from typing import Any, Callable, Optional
@@ -174,9 +173,17 @@ def _get_loss_closure_exact_internal(
     r"""ExactMarginalLogLikelihood loss closure with internally managed data."""
 
     def closure(**kwargs: Any) -> Tensor:
-        model_output = mll.model(*mll.model.train_inputs)
+        model = mll.model
+        # The inputs will get transformed in forward here.
+        model_output = model(*model.train_inputs)
         log_likelihood = mll(
-            model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs
+            model_output,
+            model.train_targets,
+            # During model training, the model inputs get transformed in the forward pass.
+            # The train_inputs property is not transformed yet, so we need to transform
+            # it before passing it to the likelihood for consistency.
+            *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
+            **kwargs,
         )
         return -log_likelihood
 
@@ -190,11 +197,19 @@ def _get_loss_closure_sum_internal(
     r"""SumMarginalLogLikelihood loss closure with internally managed data."""
 
     def closure(**kwargs: Any) -> Tensor:
-        model_output = mll.model(*mll.model.train_inputs)
+        model = mll.model
+        # The inputs will get transformed in forward here.
+        model_output = model(*model.train_inputs)
         log_likelihood = mll(
             model_output,
-            mll.model.train_targets,
-            *map(list, mll.model.train_inputs),
+            model.train_targets,
+            # During model training, the model inputs get transformed in the forward pass.
+            # The train_inputs property is not transformed yet, so we need to transform
+            # it before passing it to the likelihood for consistency.
+            *(
+                (model.transform_inputs(X=t_in) for t_in in sub_t_in)
+                for sub_t_in in model.train_inputs
+            ),
             **kwargs,
         )
         return -log_likelihood
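For the `SumMarginalLogLikelihood` case, `ModelListGP.train_inputs` is a tuple of per-model input tuples, hence the nested generator above. A minimal sketch of that structure, applying the transform per sub-model for clarity (the toy model and data are made up for illustration):

import torch
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.transforms import Normalize

train_X = 5.0 * torch.rand(10, 2, dtype=torch.double)  # not in the unit cube
train_Y = train_X.sum(dim=-1, keepdim=True)
m = SingleTaskGP(train_X=train_X, train_Y=train_Y, input_transform=Normalize(d=2))
model_list = ModelListGP(m, m)

# One inner tuple of (transformed) train inputs per sub-model.
transformed = tuple(
    tuple(sub.transform_inputs(X=t_in) for t_in in sub_t_in)
    for sub, sub_t_in in zip(model_list.models, model_list.train_inputs)
)
assert len(transformed) == 2  # one entry per sub-model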

test/optim/closures/test_model_closures.py (93 additions, 26 deletions)
@@ -17,35 +17,67 @@
 )
 from botorch.utils.testing import BotorchTestCase
 from gpytorch import settings as gpytorch_settings
+from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
 from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood
+from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
+from gpytorch.module import Module
+from torch import Tensor
 from torch.utils.data import DataLoader, TensorDataset
 
 
+# Mocking `__call__` directly leads to errors like
+# TypeError: super(type, obj): obj must be an instance or subtype of type,
+# so we wrap it manually here.
+class WrapperLikelihood(GaussianLikelihood):
+    def __init__(self, base_likelihood: GaussianLikelihood):
+        Module.__init__(self)
+        self.base_likelihood = base_likelihood
+        self.call_args = []
+
+    def __call__(self, *args, **kwargs):
+        # Store the train inputs arg for testing.
+        self.call_args.append(args[1])
+        return self.base_likelihood(*args, **kwargs)
+
+
+def _get_mlls(
+    device: torch.device, wrap_likelihood: bool = False
+) -> tuple[Tensor, list[MarginalLogLikelihood]]:
+    """Returns the train X, along with two MLLs: one for a SingleTaskGP and
+    one for a ModelListGP.
+
+    Args:
+        device: The device to use.
+        wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
+            This is useful for comparing call args later.
+    """
+    with torch.random.fork_rng():
+        torch.manual_seed(0)
+        # Inputs are not in the unit cube to ensure input transform is applied.
+        train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
+        train_Y = torch.sin((2 * pi) * train_X)
+        train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
+    mlls = []
+    model = SingleTaskGP(
+        train_X=train_X,
+        train_Y=train_Y,
+        input_transform=Normalize(d=1),
+        outcome_transform=Standardize(m=1),
+    )
+    if wrap_likelihood:
+        model.likelihood = WrapperLikelihood(model.likelihood)
+    mll = ExactMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+
+    model = ModelListGP(model, model)
+    mll = SumMarginalLogLikelihood(model.likelihood, model)
+    mlls.append(mll.to(device=device, dtype=torch.double))
+    return train_X.to(device=device, dtype=torch.double), mlls
+
+
 class TestLossClosures(BotorchTestCase):
-    def setUp(self):
-        super().setUp()
-        with torch.random.fork_rng():
-            torch.manual_seed(0)
-            train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
-            train_Y = torch.sin((2 * pi) * train_X)
-            train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-
-        self.mlls = {}
-        model = SingleTaskGP(
-            train_X=train_X,
-            train_Y=train_Y,
-            input_transform=Normalize(d=1),
-            outcome_transform=Standardize(m=1),
-        )
-        mll = ExactMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-        model = ModelListGP(model, model)
-        mll = SumMarginalLogLikelihood(model.likelihood, model)
-        self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device)
-
-    def test_main(self):
-        for mll in self.mlls.values():
+    def test_main(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             out = mll.model(*mll.model.train_inputs)
             loss = -mll(out, mll.model.train_targets).sum()
             loss.backward()
@@ -63,8 +95,8 @@ def test_main(self):
         self.assertTrue(loss.equal(_loss))
         self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads)))
 
-    def test_data_loader(self):
-        for mll in self.mlls.values():
+    def test_data_loader(self) -> None:
+        for mll in _get_mlls(device=self.device)[1]:
             if type(mll) is not ExactMarginalLogLikelihood:
                 continue
 
@@ -86,3 +118,38 @@ def test_data_loader(self):
             closure = get_loss_closure_with_grads(mll, params, data_loader=loader)
             with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"):
                 closure()
+
+    def test_with_input_transforms(self) -> None:
+        # This test reproduces the bug reported in issue #2515.
+        train_X, mlls = _get_mlls(device=self.device, wrap_likelihood=True)
+        for mll in mlls:
+            if isinstance(mll, SumMarginalLogLikelihood):
+                # The likelihood is called twice here since it is the same
+                # likelihood in both child models.
+                likelihood = mll.model.models[0].likelihood
+                expected_calls1 = 2  # In the closure call.
+                expected_calls2 = 6  # Closure + posterior calls.
+            else:
+                likelihood = mll.model.likelihood
+                expected_calls1 = 1  # In the closure call.
+                expected_calls2 = 4  # Closure + posterior calls.
+            likelihood.call_args = []  # Reset since it is shared between the models.
+            params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
+            # Evaluate the closure to mimic the model fitting process.
+            mll.train()
+            closure = get_loss_closure_with_grads(mll, params)
+            closure()
+            self.assertEqual(len(likelihood.call_args), expected_calls1)
+            # Call the model posterior to reproduce post-fitting usage.
+            mll.model.posterior(train_X, observation_noise=True)
+            # Compare the call args to ensure they're all the same.
+            # Likelihood is called twice on model(X) and once for adding the noise.
+            self.assertEqual(len(likelihood.call_args), expected_calls2)
+            arg0 = likelihood.call_args[0]
+            for i in range(1, expected_calls2):
+                argi = likelihood.call_args[i]
+                # The arg may be a tensor or a single-element list of the tensor.
+                self.assertAllClose(
+                    arg0 if isinstance(arg0, Tensor) else arg0[0],
+                    argi if isinstance(argi, Tensor) else argi[0],
+                )
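For intuition, the mismatch this test guards against can be written down directly. Before the fix, the closure forwarded the raw train inputs to the likelihood while `posterior` forwarded the Normalize-transformed ones (a rough sketch; the min-max formula below assumes `Normalize` with bounds learned from the training data):

import torch

# Toy inputs mirroring the test's train_X (deliberately not in the unit cube).
train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
# Pre-fix: what the likelihood saw from the closure (raw inputs) ...
raw = train_X
# ... versus what it saw from the posterior path (normalized inputs).
normalized = (train_X - train_X.min()) / (train_X.max() - train_X.min())
assert not torch.allclose(raw, normalized)  # the inconsistency the fix removes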
