remove input transform checks #1568

Closed · wants to merge 3 commits
22 changes: 5 additions & 17 deletions botorch/models/gp_regression_mixed.py
@@ -89,29 +89,17 @@ def __init__(
                 `.posterior` on the model will be on the original scale).
             input_transform: An input transform that is applied in the model's
                 forward pass. Only input transforms are allowed which do not
-                transform the categorical dimensions. This can be achieved
-                by using the `indices` argument when constructing the transform.
+                transform the categorical dimensions. If you want to use it,
+                e.g., in combination with a `OneHotToNumeric` input transform,
+                instantiate the transform with `transform_on_train=False` and
+                pass in the already transformed input.
         """
-        if input_transform is not None:
-            if not hasattr(input_transform, "indices"):
-                raise ValueError(
-                    "Only continuous inputs can be transformed. "
-                    "Please use `indices` in the `input_transform`."
-                )
-            # check that no cat dim is in indices
-            elif any(idx in input_transform.indices for idx in cat_dims):
-                raise ValueError(
-                    "Only continuous inputs can be transformed. "
-                    "Categorical index found in `indices` of the `input_transform`."
-                )
         if len(cat_dims) == 0:
             raise ValueError(
                 "Must specify categorical dimensions for MixedSingleTaskGP"
             )
         self._ignore_X_dims_scaling_check = cat_dims
-        input_batch_shape, aug_batch_shape = self.get_batch_dimensions(
-            train_X=train_X, train_Y=train_Y
-        )
+        _, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y)

         if cont_kernel_factory is None:

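The new docstring describes a usage pattern instead of enforcing one. Below is a minimal sketch of that pattern, assuming `OneHotToNumeric` from `botorch.models.transforms.input`; the one-hot width `dim=5`, the `{2: 3}` start-column-to-cardinality mapping, and all data shapes are illustrative assumptions, not part of this PR:

    import torch

    from botorch.models import MixedSingleTaskGP
    from botorch.models.transforms.input import OneHotToNumeric

    # Illustrative setup: 2 continuous columns plus 1 categorical feature with
    # 3 levels, one-hot encoded at evaluation time (5 raw columns in total).
    transform = OneHotToNumeric(
        dim=5,                        # width of the one-hot-encoded input
        categorical_features={2: 3},  # assumed: start column -> cardinality
        transform_on_train=False,     # per the docstring: don't transform train_X
    )

    # train_X is supplied already in the numeric (transformed) space: columns
    # 0 and 1 are continuous, column 2 holds the integer category label.
    train_X = torch.rand(20, 3)
    train_X[:, 2] = torch.randint(0, 3, (20,)).to(train_X)
    train_Y = torch.randn(20, 1)

    model = MixedSingleTaskGP(
        train_X, train_Y, cat_dims=[2], input_transform=transform
    )

With `transform_on_eval` left at its default, candidate points presented in the one-hot space are mapped back to the numeric space the model was trained on when the posterior is evaluated.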
34 changes: 0 additions & 34 deletions test/models/test_gp_regression_mixed.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import itertools
-import random
 import warnings

 import torch
@@ -44,15 +43,6 @@ def test_gp(self):
             )
             cat_dims = list(range(ncat))
             ord_dims = sorted(set(range(d)) - set(cat_dims))
-            with self.assertRaises(ValueError):
-                MixedSingleTaskGP(
-                    train_X,
-                    train_Y,
-                    cat_dims=cat_dims,
-                    input_transform=Normalize(
-                        d=d, bounds=bounds.to(**tkwargs), transform_on_train=True
-                    ),
-                )
             # test correct indices
             if (ncat < 3) and (ncat > 0):
                 MixedSingleTaskGP(
@@ -66,30 +56,6 @@ def test_gp(self):
                         indices=ord_dims,
                     ),
                 )
-            with self.assertRaises(ValueError):
-                MixedSingleTaskGP(
-                    train_X,
-                    train_Y,
-                    cat_dims=cat_dims,
-                    input_transform=Normalize(
-                        d=d,
-                        bounds=bounds.to(**tkwargs),
-                        transform_on_train=True,
-                        indices=cat_dims,
-                    ),
-                )
-            with self.assertRaises(ValueError):
-                MixedSingleTaskGP(
-                    train_X,
-                    train_Y,
-                    cat_dims=cat_dims,
-                    input_transform=Normalize(
-                        d=d,
-                        bounds=bounds.to(**tkwargs),
-                        transform_on_train=True,
-                        indices=ord_dims + [random.choice(cat_dims)],
-                    ),
-                )

             if len(cat_dims) == 0:
                 with self.assertRaises(ValueError):
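For contrast, the pattern the surviving test above still exercises: a `Normalize` transform restricted to the continuous columns via `indices`, so the categorical columns pass through untouched. A rough, self-contained sketch (`d`, `ncat`, and all data here are illustrative):

    import torch

    from botorch.models import MixedSingleTaskGP
    from botorch.models.transforms.input import Normalize

    d, ncat = 4, 2
    cat_dims = list(range(ncat))                      # [0, 1]
    ord_dims = sorted(set(range(d)) - set(cat_dims))  # [2, 3]

    train_X = torch.rand(10, d)
    train_X[:, cat_dims] = train_X[:, cat_dims].round()  # binary category labels
    train_Y = torch.randn(10, 1)
    bounds = torch.stack([torch.zeros(d), torch.ones(d)])

    # Normalize only the continuous dims; categorical dims are left as-is.
    model = MixedSingleTaskGP(
        train_X,
        train_Y,
        cat_dims=cat_dims,
        input_transform=Normalize(d=d, bounds=bounds, indices=ord_dims),
    )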