From ea7b1ea2e8f8ed20235b3ecc917153603f995eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Thu, 15 Dec 2022 15:30:46 +0100 Subject: [PATCH 1/2] remove input transform checks --- botorch/models/gp_regression_mixed.py | 22 ++++------------- test/models/test_gp_regression_mixed.py | 33 ------------------------- 2 files changed, 5 insertions(+), 50 deletions(-) diff --git a/botorch/models/gp_regression_mixed.py b/botorch/models/gp_regression_mixed.py index 2851a5e57c..668035541f 100644 --- a/botorch/models/gp_regression_mixed.py +++ b/botorch/models/gp_regression_mixed.py @@ -89,29 +89,17 @@ def __init__( `.posterior` on the model will be on the original scale). input_transform: An input transform that is applied in the model's forward pass. Only input transforms are allowed which do not - transform the categorical dimensions. This can be achieved - by using the `indices` argument when constructing the transform. + transform the categorical dimensions. If you want to use it + for example in combination with a `OneHotToNumeric` input transform + one has to instantiate the transform with `transform_on_train` == False + and pass in the already transformed input. """ - if input_transform is not None: - if not hasattr(input_transform, "indices"): - raise ValueError( - "Only continuous inputs can be transformed. " - "Please use `indices` in the `input_transform`." - ) - # check that no cat dim is in indices - elif any(idx in input_transform.indices for idx in cat_dims): - raise ValueError( - "Only continuous inputs can be transformed. " - "Categorical index found in `indices` of the `input_transform`." - ) if len(cat_dims) == 0: raise ValueError( "Must specify categorical dimensions for MixedSingleTaskGP" ) self._ignore_X_dims_scaling_check = cat_dims - input_batch_shape, aug_batch_shape = self.get_batch_dimensions( - train_X=train_X, train_Y=train_Y - ) + _, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y) if cont_kernel_factory is None: diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py index 2044fdf331..ea14f7b960 100644 --- a/test/models/test_gp_regression_mixed.py +++ b/test/models/test_gp_regression_mixed.py @@ -43,15 +43,6 @@ def test_gp(self): ) cat_dims = list(range(ncat)) ord_dims = sorted(set(range(d)) - set(cat_dims)) - with self.assertRaises(ValueError): - MixedSingleTaskGP( - train_X, - train_Y, - cat_dims=cat_dims, - input_transform=Normalize( - d=d, bounds=bounds.to(**tkwargs), transform_on_train=True - ), - ) # test correct indices if (ncat < 3) and (ncat > 0): MixedSingleTaskGP( @@ -65,30 +56,6 @@ def test_gp(self): indices=ord_dims, ), ) - with self.assertRaises(ValueError): - MixedSingleTaskGP( - train_X, - train_Y, - cat_dims=cat_dims, - input_transform=Normalize( - d=d, - bounds=bounds.to(**tkwargs), - transform_on_train=True, - indices=cat_dims, - ), - ) - with self.assertRaises(ValueError): - MixedSingleTaskGP( - train_X, - train_Y, - cat_dims=cat_dims, - input_transform=Normalize( - d=d, - bounds=bounds.to(**tkwargs), - transform_on_train=True, - indices=ord_dims + [random.choice(cat_dims)], - ), - ) if len(cat_dims) == 0: with self.assertRaises(ValueError): From ca0ed76470e385c5935980d2862142cc70ebf647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Thu, 15 Dec 2022 15:51:13 +0100 Subject: [PATCH 2/2] fix flake8 --- test/models/test_gp_regression_mixed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/models/test_gp_regression_mixed.py b/test/models/test_gp_regression_mixed.py index ea14f7b960..c6c38fd008 100644 --- a/test/models/test_gp_regression_mixed.py +++ b/test/models/test_gp_regression_mixed.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import itertools -import random import warnings import torch