Commit 5333d5b

Use gpytorch constraints for bounds on parameters during model fitting.
Summary: gpytorch now supports defining constraints on its submodules. This allows specifying parameter constraints where they belong (on the model) and having model fitting deal with them in a generic way. Note that a constraint whose transform is not `None` automatically enforces the bound by applying that transform. This can be an issue for quasi-second-order optimizers, though, because the objective becomes flat when the line search overshoots past the effective constraint. Hence, skipping the transform and imposing an explicit bound is preferred. It may also be beneficial to use the transform in conjunction with an explicit bound; that will need further evaluation.

Reviewed By: bletham

Differential Revision: D14840983

fbshipit-source-id: 6f52ec9eb0b970a692963083125e58df55a46de5
1 parent 44f18fd commit 5333d5b
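A minimal sketch of the mechanism the summary describes, using only the gpytorch API exercised in this diff (the 1e-6 bound mirrors MIN_INFERRED_NOISE_LEVEL below):

from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods import GaussianLikelihood

# With transform=None the constraint is registered on the module but not
# enforced via a parameter transform, so the raw parameter remains
# unconstrained during optimization and the bound must be imposed explicitly
# by the fitting code.
likelihood = GaussianLikelihood(
    noise_constraint=GreaterThan(1e-6, transform=None)
)

# Fitting code can discover such bounds generically:
for name, _, constraint in likelihood.named_parameters_and_constraints():
    if constraint is not None and not constraint.enforced:
        print(name, constraint.lower_bound, constraint.upper_bound)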

File tree

5 files changed (+255, -115 lines changed)

botorch/models/gp_regression.py

Lines changed: 9 additions & 7 deletions
@@ -8,6 +8,7 @@
 from typing import Any, Optional

 import torch
+from gpytorch.constraints.constraints import GreaterThan
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels.matern_kernel import MaternKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -27,6 +28,9 @@
 from .gpytorch import GPyTorchModel


+MIN_INFERRED_NOISE_LEVEL = 1e-6
+
+
 class SingleTaskGP(ExactGP, GPyTorchModel):
     r"""A single-task Exact GP model.
@@ -58,10 +62,10 @@ def __init__(
             raise ValueError(f"Unsupported shape {train_X.shape} for train_X.")
         if likelihood is None:
             likelihood = GaussianLikelihood(
-                noise_prior=GammaPrior(1.1, 0.05), batch_size=batch_size
+                noise_prior=GammaPrior(1.1, 0.05),
+                batch_size=batch_size,
+                noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=None),
             )
-            # TODO: Use gpytorch constraints
-            likelihood.parameter_bounds = {"noise_covar.raw_noise": (-15, None)}
         else:
             self._likelihood_state_dict = deepcopy(likelihood.state_dict())
         super().__init__(train_X, train_Y, likelihood)
@@ -183,15 +187,13 @@ class HeteroskedasticSingleTaskGP(SingleTaskGP):
     def __init__(self, train_X: Tensor, train_Y: Tensor, train_Y_se: Tensor) -> None:
         train_Y_log_var = 2 * torch.log(train_Y_se)
         noise_likelihood = GaussianLikelihood(
-            noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log)
+            noise_prior=SmoothedBoxPrior(-3, 5, 0.5, transform=torch.log),
+            noise_constraint=GreaterThan(MIN_INFERRED_NOISE_LEVEL, transform=None),
         )
         noise_model = SingleTaskGP(
             train_X=train_X, train_Y=train_Y_log_var, likelihood=noise_likelihood
         )
         likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
-        likelihood.parameter_bounds = {
-            "noise_covar.noise_model.likelihood.noise_covar.raw_noise": (-15, None)
-        }
         super().__init__(train_X=train_X, train_Y=train_Y, likelihood=likelihood)
         self.to(train_X)
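Usage sketch for the change above, on toy data; the raw_noise_constraint attribute name follows gpytorch's register_constraint convention and is an assumption here:

import torch
from botorch.models.gp_regression import SingleTaskGP

train_X = torch.rand(20, 2)  # toy inputs (illustrative)
train_Y = torch.sin(train_X).sum(dim=-1)  # toy targets (illustrative)

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

# The noise bound now lives on the likelihood as a constraint. Created with
# transform=None, it is not enforced by a transform (enforced is False), so
# fitting routines are expected to impose it explicitly.
constraint = model.likelihood.noise_covar.raw_noise_constraint  # assumed attribute name
print(constraint.enforced, constraint.lower_bound)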

botorch/optim/fit.py

Lines changed: 23 additions & 2 deletions
@@ -26,6 +26,7 @@ class OptimizationIteration(NamedTuple):

 def fit_gpytorch_torch(
     mll: MarginalLogLikelihood,
+    bounds: Optional[ParameterBounds] = None,
     optimizer_cls: Optimizer = Adam,
     lr: float = 0.05,
     maxiter: int = 100,
@@ -41,6 +42,10 @@

     Args:
         mll: MarginalLogLikelihood to be maximized.
+        bounds: A ParameterBounds dictionary mapping parameter names to tuples
+            of lower and upper bounds. Bounds specified here take precedence
+            over bounds on the same parameters specified in the constraints
+            registered with the module.
         optimizer_cls: Torch optimizer to use. Must not need a closure.
             Defaults to Adam.
         lr: Starting learning rate.
@@ -60,6 +65,17 @@
         params=[{"params": mll.parameters()}], lr=lr, **optimizer_args
     )

+    # get bounds specified in model (if any)
+    bounds_: ParameterBounds = {}
+    if hasattr(mll, "named_parameters_and_constraints"):
+        for param_name, _, constraint in mll.named_parameters_and_constraints():
+            if constraint is not None and not constraint.enforced:
+                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound
+
+    # update with user-supplied bounds (overwrites if already exists)
+    if bounds is not None:
+        bounds_.update(bounds)
+
     iterations = []
     t1 = time.time()

@@ -80,11 +96,16 @@
         loss_trajectory.append(loss.item())
         for name, param in mll.named_parameters():
             param_trajectory[name].append(param.detach().clone())
-        if disp and (i % 10 == 0 or i == (maxiter - 1)):
-            print(f"Iter {i +1}/{maxiter}: {loss.item()}")
+        if disp and ((i + 1) % 10 == 0 or i == (maxiter - 1)):
+            print(f"Iter {i + 1}/{maxiter}: {loss.item()}")
         if track_iterations:
             iterations.append(OptimizationIteration(i, loss.item(), time.time() - t1))
         optimizer.step()
+        # project onto bounds:
+        if bounds_:
+            for pname, param in mll.named_parameters():
+                if pname in bounds_:
+                    param.data = param.data.clamp(*bounds_[pname])
         i += 1
     converged = check_convergence(
         loss_trajectory=loss_trajectory,
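The loop above is effectively projected gradient descent: after each optimizer step, raw parameters with non-enforced constraints are clamped back into their bounds. A self-contained sketch of just that projection step (the parameter name and bounds are illustrative):

import torch

param = torch.nn.Parameter(torch.tensor([-2.0]))
bounds_ = {"raw_noise": (-1.0, None)}  # lower bound only, as with GreaterThan

optimizer = torch.optim.Adam([param], lr=0.1)
for _ in range(5):
    optimizer.zero_grad()
    loss = (param ** 2).sum()  # stand-in objective
    loss.backward()
    optimizer.step()
    # project back onto the feasible region, mirroring the clamp above
    lower, upper = bounds_["raw_noise"]
    param.data = param.data.clamp(min=lower, max=upper)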

botorch/optim/numpy_converter.py

Lines changed: 20 additions & 18 deletions
@@ -2,7 +2,7 @@

 from collections import OrderedDict
 from math import inf
-from typing import Dict, List, NamedTuple, Optional, Tuple
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple

 import numpy as np
 import torch
@@ -19,18 +19,22 @@ class TorchAttr(NamedTuple):


 def module_to_array(
-    module: Module, bounds: Optional[ParameterBounds] = None
+    module: Module,
+    bounds: Optional[ParameterBounds] = None,
+    exclude: Optional[Set[str]] = None,
 ) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]:
     r"""Extract named parameters from a module into a numpy array.

     Only extracts parameters with requires_grad, since it is meant for optimizing.

     Args:
         module: A module with parameters. May specify parameter constraints in
-            a `parameter_bounds` attribute.
-        bounds: A ParameterBounds dictionary mapping parameter names to tuples of
-            lower and upper bounds. Bounds specified here take precedence over
-            bounds specified in the `parameter_bounds` attribute of the module.
+            a `named_parameters_and_constraints` method.
+        bounds: A ParameterBounds dictionary mapping parameter names to tuples
+            of lower and upper bounds. Bounds specified here take precedence
+            over bounds on the same parameters specified in the constraints
+            registered with the module.
+        exclude: A set of parameter names that are to be excluded from extraction.

     Returns:
         np.ndarray: The parameter values
@@ -43,23 +47,21 @@ def module_to_array(
     lower: List[np.ndarray] = []
     upper: List[np.ndarray] = []
     property_dict = OrderedDict()
+    exclude = set() if exclude is None else exclude

-    # extract parameter bounds from module.model.parameter_bounds and
-    # module.likelihood.parameter_bounds (if present)
-    model_bounds = getattr(getattr(module, "model", None), "parameter_bounds", {})
-    bounds_ = {".".join(["model", key]): val for key, val in model_bounds.items()}
-    likelihood_bounds = getattr(
-        getattr(module, "likelihood", None), "parameter_bounds", {}
-    )
-    bounds_.update(
-        {".".join(["likelihood", key]): val for key, val in likelihood_bounds.items()}
-    )
-    # update with user-supplied bounds
+    # get bounds specified in model (if any)
+    bounds_: ParameterBounds = {}
+    if hasattr(module, "named_parameters_and_constraints"):
+        for param_name, _, constraint in module.named_parameters_and_constraints():
+            if constraint is not None and not constraint.enforced:
+                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound
+
+    # update with user-supplied bounds (overwrites if already exists)
     if bounds is not None:
         bounds_.update(bounds)

     for p_name, t in module.named_parameters():
-        if t.requires_grad:
+        if p_name not in exclude and t.requires_grad:
             property_dict[p_name] = TorchAttr(
                 shape=t.shape, dtype=t.dtype, device=t.device
             )
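Usage sketch for the updated module_to_array; the model setup and the parameter names passed to bounds/exclude are illustrative and depend on the actual module structure:

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim.numpy_converter import module_to_array
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2)
train_Y = torch.rand(10)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Bounds are collected from non-enforced constraints registered on the module;
# user-supplied bounds override them, and exclude drops parameters from the
# flattened array entirely.
x, property_dict, bounds_out = module_to_array(
    module=mll,
    bounds={"likelihood.noise_covar.raw_noise": (-10.0, None)},  # illustrative name
    exclude={"model.mean_module.constant"},  # illustrative name
)
print(x.shape, list(property_dict))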
