Skip to content

Commit 52676e9

Browse files
sdaulton and facebook-github-bot
authored and committed
add utility for constructing rounding input transforms (#1531)
Summary: Pull Request resolved: #1531 see title Differential Revision: https://internalfb.com/D41497584 fbshipit-source-id: 208098b14ccb3030e6d0e7185d05e0e60adeb0b3
1 parent a5b8efc commit 52676e9

File tree

6 files changed

+301
-3
lines changed

6 files changed

+301
-3
lines changed

botorch/models/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from botorch.models.transforms.factory import get_rounding_input_transform
78
from botorch.models.transforms.input import (
89
ChainedInputTransform,
910
Normalize,
@@ -20,6 +21,7 @@
2021

2122

2223
__all__ = [
24+
"get_rounding_input_transform",
2325
"Bilog",
2426
"ChainedInputTransform",
2527
"ChainedOutcomeTransform",

botorch/models/transforms/factory.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from collections import OrderedDict
10+
from typing import Dict, List, Optional
11+
12+
from botorch.models.transforms.input import (
13+
ChainedInputTransform,
14+
Normalize,
15+
OneHotToNumeric,
16+
Round,
17+
)
18+
from torch import Tensor
19+
20+
21+
def get_rounding_input_transform(
    one_hot_bounds: Tensor,
    integer_indices: Optional[List[int]] = None,
    categorical_features: Optional[Dict[int, int]] = None,
    initialization: bool = False,
    return_numeric: bool = False,
    approximate: bool = False,
) -> ChainedInputTransform:
    """Get a rounding input transform.

    The rounding function will take inputs from the unit cube,
    unnormalize the integers to the raw search space, round the inputs,
    and normalize them back to the unit cube.

    Categoricals are assumed to be one-hot encoded. Integers are
    currently assumed to be contiguous ranges (e.g. [1,2,3] and not
    [1,5,7]).

    TODO: support non-contiguous sets of integers by modifying
    the rounding function.

    Args:
        one_hot_bounds: The raw search space bounds where categoricals are
            encoded in one-hot representation and the integer parameters
            are not normalized. Row 0 holds the lower bounds and row 1 the
            upper bounds; the size of dimension 1 is used as the input
            dimension of the transforms.
        integer_indices: The indices of the integer parameters.
        categorical_features: A dictionary mapping indices to cardinalities
            for the categorical features.
        initialization: A boolean indicating whether this exact rounding
            function is for initialization. For initialization, the bounds
            of the integer parameters are expanded such that the end point
            of a range is selected with the same probability that an
            interior point is selected, after rounding.
        return_numeric: A boolean indicating whether to return numeric or
            one-hot encoded categoricals. Returning a numeric
            representation is helpful if the downstream code (e.g. kernel)
            expects a numeric representation of the categoricals. Ignored
            when there are no categorical features.
        approximate: A boolean indicating whether to use an approximate
            rounding function.

    Returns:
        The rounding function as a ChainedInputTransform. Its entries are,
        in order: "unnormalize_tf" (only with integers), "round",
        "normalize_tf" (only with integers) and "one_hot_to_numeric" (only
        with `return_numeric=True` and categoricals). The chain is moved to
        the dtype and device of `one_hot_bounds` and put in eval mode.

    Raises:
        ValueError: If neither integer indices nor categorical features are
            provided (the transform would be a no-op).
    """
    has_integers = integer_indices is not None and len(integer_indices) > 0
    has_categoricals = (
        categorical_features is not None and len(categorical_features) > 0
    )
    if not (has_integers or has_categoricals):
        raise ValueError(
            "A rounding function is a no-op "
            "if there are no integer or categorical parameters."
        )
    if initialization and has_integers:
        # this gives the extreme integer values (end points)
        # the same probability as the interior values of the range;
        # work on a copy so the caller's bounds are not modified
        init_one_hot_bounds = one_hot_bounds.clone()
        init_one_hot_bounds[0, integer_indices] -= 0.4999
        init_one_hot_bounds[1, integer_indices] += 0.4999
    else:
        init_one_hot_bounds = one_hot_bounds

    # insertion order defines the order in which the transforms are chained
    tfs = OrderedDict()
    if has_integers:
        # unnormalize to integer space (uses the possibly expanded bounds)
        tfs["unnormalize_tf"] = Normalize(
            d=init_one_hot_bounds.shape[1],
            bounds=init_one_hot_bounds,
            indices=integer_indices,
            transform_on_train=False,
            transform_on_eval=True,
            transform_on_fantasize=True,
            reverse=True,
        )
    # round
    tfs["round"] = Round(
        approximate=approximate,
        transform_on_train=False,
        transform_on_fantasize=True,
        integer_indices=integer_indices,
        categorical_features=categorical_features,
    )
    if has_integers:
        # renormalize to unit cube (always with the original bounds)
        tfs["normalize_tf"] = Normalize(
            d=one_hot_bounds.shape[1],
            bounds=one_hot_bounds,
            indices=integer_indices,
            transform_on_train=False,
            transform_on_eval=True,
            transform_on_fantasize=True,
            reverse=False,
        )
    if return_numeric and has_categoricals:
        tfs["one_hot_to_numeric"] = OneHotToNumeric(
            # this is the dimension using one-hot encoded representation
            dim=one_hot_bounds.shape[-1],
            categorical_features=categorical_features,
            transform_on_train=True,
            transform_on_eval=True,
            transform_on_fantasize=True,
        )
    tf = ChainedInputTransform(**tfs)
    tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device)
    tf.eval()
    return tf

botorch/models/transforms/input.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,9 +1383,9 @@ def __init__(
13831383
self,
13841384
dim: int,
13851385
categorical_features: Optional[Dict[int, int]] = None,
1386-
transform_on_train: bool = False,
1386+
transform_on_train: bool = True,
13871387
transform_on_eval: bool = True,
1388-
transform_on_fantasize: bool = False,
1388+
transform_on_fantasize: bool = True,
13891389
) -> None:
13901390
r"""Initialize.
13911391

botorch/models/transforms/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from __future__ import annotations
88

99
from functools import wraps
10-
1110
from typing import Tuple
1211

1312
import torch

sphinx/source/models.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ Input Transforms
136136
.. automodule:: botorch.models.transforms.input
137137
:members:
138138

139+
Transform Factory Methods
140+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
141+
.. automodule:: botorch.models.transforms.factory
142+
:members:
143+
139144
Transform Utilities
140145
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
141146
.. automodule:: botorch.models.transforms.utils
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import torch
9+
from botorch.models.transforms.factory import get_rounding_input_transform
10+
from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round
11+
from botorch.utils.rounding import OneHotArgmaxSTE
12+
from botorch.utils.testing import BotorchTestCase
13+
from botorch.utils.transforms import normalize, unnormalize
14+
15+
16+
class TestGetRoundingInputTransform(BotorchTestCase):
    """Tests for the `get_rounding_input_transform` factory."""

    def test_get_rounding_input_transform(self):
        """Check the composition and forward pass of the returned chain.

        Covers: the error on missing integers/categoricals, the default
        three-step chain, integer-only and categorical-only chains, the
        expanded bounds used for initialization, and the optional
        one-hot-to-numeric step.
        """
        for dtype in (torch.float, torch.double):
            # 2 x 7 bounds: column 0 is continuous, column 1 is an integer
            # parameter, columns 2-3 and 4-6 are one-hot categoricals
            one_hot_bounds = torch.tensor(
                [
                    [0, 5],
                    [0, 4],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                    [0, 1],
                ],
                dtype=dtype,
                device=self.device,
            ).t()
            with self.assertRaises(ValueError):
                # test no integer or categorical
                get_rounding_input_transform(
                    one_hot_bounds=one_hot_bounds,
                )
            integer_indices = [1]
            categorical_features = {2: 2, 4: 3}
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
            )
            self.assertIsInstance(tf, ChainedInputTransform)
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            # test unnormalize
            tf_name_i, tf_i = tfs[0]
            self.assertEqual(tf_name_i, "unnormalize_tf")
            self.assertIsInstance(tf_i, Normalize)
            self.assertTrue(tf_i.reverse)
            bounds = one_hot_bounds[:, integer_indices]
            offset = bounds[:1, :]
            coefficient = bounds[1:2, :] - offset
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])
            # compare as lists: assertEqual on tensors relies on the
            # truthiness of an elementwise comparison, which is only
            # well-defined for single-element tensors
            self.assertEqual(tf_i.indices.tolist(), integer_indices)
            # test round
            tf_name_i, tf_i = tfs[1]
            self.assertEqual(tf_name_i, "round")
            self.assertIsInstance(tf_i, Round)
            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
            self.assertEqual(tf_i.categorical_features, categorical_features)
            # test normalize
            tf_name_i, tf_i = tfs[2]
            self.assertEqual(tf_name_i, "normalize_tf")
            self.assertIsInstance(tf_i, Normalize)
            self.assertFalse(tf_i.reverse)
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            self.assertEqual(tf_i._d, one_hot_bounds.shape[1])

            # test forward
            X = torch.rand(
                2, 4, one_hot_bounds.shape[1], dtype=dtype, device=self.device
            )
            X_tf = tf(X)
            # assert the continuous param is unaffected
            self.assertTrue(torch.equal(X_tf[..., 0], X[..., 0]))
            # check that integer params are rounded
            X_int = X[..., integer_indices]
            unnormalized_X_int = unnormalize(X_int, bounds)
            rounded_X_int = normalize(unnormalized_X_int.round(), bounds)
            self.assertTrue(torch.equal(rounded_X_int, X_tf[..., integer_indices]))
            # check that categoricals are discretized
            for start, card in categorical_features.items():
                end = start + card
                discretized_feat = OneHotArgmaxSTE.apply(X[..., start:end])
                self.assertTrue(torch.equal(discretized_feat, X_tf[..., start:end]))
            # test transform on train/eval/fantasize
            for tf_i in tf.values():
                self.assertFalse(tf_i.transform_on_train)
                self.assertTrue(tf_i.transform_on_eval)
                self.assertTrue(tf_i.transform_on_fantasize)

            # test no integer
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                categorical_features=categorical_features,
            )
            tfs = list(tf.items())
            # round should be the only transform
            self.assertEqual(len(tfs), 1)
            tf_name_i, tf_i = tfs[0]
            self.assertEqual(tf_name_i, "round")
            self.assertIsInstance(tf_i, Round)
            self.assertEqual(tf_i.integer_indices.tolist(), [])
            self.assertEqual(tf_i.categorical_features, categorical_features)
            # test no categoricals
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            _, tf_i = tfs[1]
            self.assertEqual(tf_i.integer_indices.tolist(), integer_indices)
            self.assertEqual(tf_i.categorical_features, {})
            # test initialization
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
                initialization=True,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 3)
            # check that bounds are adjusted for integers on unnormalize
            _, tf_i = tfs[0]
            offset_init = bounds[:1, :] - 0.4999
            coefficient_init = bounds[1:2, :] + 0.4999 - offset_init
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient_init))
            self.assertTrue(torch.equal(tf_i.offset, offset_init))
            # check that the original (unexpanded) bounds are still used
            # for integers on normalize
            _, tf_i = tfs[2]
            self.assertTrue(torch.equal(tf_i.coefficient, coefficient))
            self.assertTrue(torch.equal(tf_i.offset, offset))
            # test return numeric
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                categorical_features=categorical_features,
                return_numeric=True,
            )
            tfs = list(tf.items())
            self.assertEqual(len(tfs), 4)
            tf_name_i, tf_i = tfs[3]
            self.assertEqual(tf_name_i, "one_hot_to_numeric")
            # transform to numeric on train
            # (e.g. for kernels that expect a integer representation)
            self.assertTrue(tf_i.transform_on_train)
            self.assertTrue(tf_i.transform_on_eval)
            self.assertTrue(tf_i.transform_on_fantasize)
            self.assertEqual(tf_i.categorical_features, categorical_features)
            self.assertEqual(tf_i.numeric_dim, 4)
            # test return numeric and no categorical
            tf = get_rounding_input_transform(
                one_hot_bounds=one_hot_bounds,
                integer_indices=integer_indices,
                return_numeric=True,
            )
            tfs = list(tf.items())
            # there should be no one hot to numeric transform
            self.assertEqual(len(tfs), 3)

0 commit comments

Comments
 (0)