17
17
)
18
18
from botorch .utils .testing import BotorchTestCase
19
19
from gpytorch import settings as gpytorch_settings
20
+ from gpytorch .likelihoods .gaussian_likelihood import GaussianLikelihood
20
21
from gpytorch .mlls import ExactMarginalLogLikelihood , SumMarginalLogLikelihood
22
+ from gpytorch .mlls .marginal_log_likelihood import MarginalLogLikelihood
23
+ from gpytorch .module import Module
24
+ from torch import Tensor
21
25
from torch .utils .data import DataLoader , TensorDataset
22
26
23
27
28
# Mock-patching __call__ directly raises errors like
# TypeError: super(type, obj): obj must be an instance or subtype of type
# so the wrapping is done manually here.
class WrapperLikelihood(GaussianLikelihood):
    """A ``GaussianLikelihood`` wrapper that records the second positional
    argument of every ``__call__`` so tests can inspect it later.
    """

    def __init__(self, base_likelihood: GaussianLikelihood) -> None:
        """Initialize the wrapper around an existing likelihood.

        Args:
            base_likelihood: The likelihood that all calls are delegated to.
        """
        Module.__init__(self)
        self.base_likelihood = base_likelihood
        # Accumulates the train-inputs argument of each call for assertions.
        self.call_args: list = []

    def __call__(self, *args, **kwargs):
        # Store the train inputs arg (second positional) for testing,
        # then delegate to the wrapped likelihood unchanged.
        self.call_args.append(args[1])
        return self.base_likelihood(*args, **kwargs)
41
+
42
+
43
def _get_mlls(
    device: torch.device, wrap_likelihood: bool = False
) -> tuple[Tensor, list[MarginalLogLikelihood]]:
    """Build the train X along with two MLLs: one for a SingleTaskGP and
    one for a ModelListGP made of two copies of that same model.

    Args:
        device: The device to use.
        wrap_likelihood: If True, wrap the likelihood in a WrapperLikelihood.
            This is useful for comparing call args later.
    """
    with torch.random.fork_rng():
        torch.manual_seed(0)
        # Inputs are not in the unit cube to ensure input transform is applied.
        train_X = torch.linspace(0, 5, 10).unsqueeze(-1)
        train_Y = torch.sin((2 * pi) * train_X)
        train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
        model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            input_transform=Normalize(d=1),
            outcome_transform=Standardize(m=1),
        )
        if wrap_likelihood:
            model.likelihood = WrapperLikelihood(model.likelihood)
        exact_mll = ExactMarginalLogLikelihood(model.likelihood, model)
        mlls = [exact_mll.to(device=device, dtype=torch.double)]

        # NOTE: the same model (and hence the same likelihood) appears twice.
        model_list = ModelListGP(model, model)
        sum_mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
        mlls.append(sum_mll.to(device=device, dtype=torch.double))
    return train_X.to(device=device, dtype=torch.double), mlls
76
+
77
+
24
78
class TestLossClosures (BotorchTestCase ):
25
- def setUp (self ):
26
- super ().setUp ()
27
- with torch .random .fork_rng ():
28
- torch .manual_seed (0 )
29
- train_X = torch .linspace (0 , 1 , 10 ).unsqueeze (- 1 )
30
- train_Y = torch .sin ((2 * pi ) * train_X )
31
- train_Y = train_Y + 0.1 * torch .randn_like (train_Y )
32
-
33
- self .mlls = {}
34
- model = SingleTaskGP (
35
- train_X = train_X ,
36
- train_Y = train_Y ,
37
- input_transform = Normalize (d = 1 ),
38
- outcome_transform = Standardize (m = 1 ),
39
- )
40
- mll = ExactMarginalLogLikelihood (model .likelihood , model )
41
- self .mlls [type (mll ), type (model .likelihood ), type (model )] = mll .to (self .device )
42
-
43
- model = ModelListGP (model , model )
44
- mll = SumMarginalLogLikelihood (model .likelihood , model )
45
- self .mlls [type (mll ), type (model .likelihood ), type (model )] = mll .to (self .device )
46
-
47
- def test_main (self ):
48
- for mll in self .mlls .values ():
79
+ def test_main (self ) -> None :
80
+ for mll in _get_mlls (device = self .device )[1 ]:
49
81
out = mll .model (* mll .model .train_inputs )
50
82
loss = - mll (out , mll .model .train_targets ).sum ()
51
83
loss .backward ()
@@ -63,8 +95,8 @@ def test_main(self):
63
95
self .assertTrue (loss .equal (_loss ))
64
96
self .assertTrue (all (a .equal (b ) for a , b in zip_longest (grads , _grads )))
65
97
66
- def test_data_loader (self ):
67
- for mll in self .mlls . values () :
98
+ def test_data_loader (self ) -> None :
99
+ for mll in _get_mlls ( device = self .device )[ 1 ] :
68
100
if type (mll ) is not ExactMarginalLogLikelihood :
69
101
continue
70
102
@@ -86,3 +118,38 @@ def test_data_loader(self):
86
118
closure = get_loss_closure_with_grads (mll , params , data_loader = loader )
87
119
with self .assertRaisesRegex (TypeError , "Expected .* a batch of tensors" ):
88
120
closure ()
121
+
122
+ def test_with_input_transforms (self ) -> None :
123
+ # This test reproduces the bug reported in issue #2515.
124
+ train_X , mlls = _get_mlls (device = self .device , wrap_likelihood = True )
125
+ for mll in mlls :
126
+ if isinstance (mll , SumMarginalLogLikelihood ):
127
+ # The likelihood is called twice here since it is the same
128
+ # likelihood in both child models.
129
+ likelihood = mll .model .models [0 ].likelihood
130
+ expected_calls1 = 2 # In the closure call.
131
+ expected_calls2 = 6 # Closure + posterior calls.
132
+ else :
133
+ likelihood = mll .model .likelihood
134
+ expected_calls1 = 1 # In the closure call.
135
+ expected_calls2 = 4 # Closure + posterior calls.
136
+ likelihood .call_args = [] # reset since it is shared between the models.
137
+ params = {n : p for n , p in mll .named_parameters () if p .requires_grad }
138
+ # Evaluate the closure to mimic the model fitting process.
139
+ mll .train ()
140
+ closure = get_loss_closure_with_grads (mll , params )
141
+ closure ()
142
+ self .assertEqual (len (likelihood .call_args ), expected_calls1 )
143
+ # Call the model posterior to reproduce post-fitting usage.
144
+ mll .model .posterior (train_X , observation_noise = True )
145
+ # Compare the call args to ensure they're all the same.
146
+ # Likelihood is called twice on model(X) and once for adding the noise.
147
+ self .assertEqual (len (likelihood .call_args ), expected_calls2 )
148
+ arg0 = likelihood .call_args [0 ]
149
+ for i in range (1 , expected_calls2 ):
150
+ argi = likelihood .call_args [i ]
151
+ # The arg may be a tensor or a single element list of the tensor.
152
+ self .assertAllClose (
153
+ arg0 if isinstance (arg0 , Tensor ) else arg0 [0 ],
154
+ argi if isinstance (argi , Tensor ) else argi [0 ],
155
+ )
0 commit comments